001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.optim.nonlinear.scalar.noderiv;
018
019import java.util.Arrays;
020import java.util.List;
021import java.util.ArrayList;
022import java.util.Collections;
023import java.util.Comparator;
024import org.apache.commons.rng.UniformRandomProvider;
025import org.apache.commons.rng.simple.RandomSource;
026import org.apache.commons.math4.analysis.MultivariateFunction;
027import org.apache.commons.math4.exception.DimensionMismatchException;
028import org.apache.commons.math4.exception.NotStrictlyPositiveException;
029import org.apache.commons.math4.exception.MathInternalError;
030import org.apache.commons.math4.optim.PointValuePair;
031import org.apache.commons.math4.optim.OptimizationData;
032import org.apache.commons.math4.optim.ConvergenceChecker;
033import org.apache.commons.math4.optim.SimpleValueChecker;
034import org.apache.commons.math4.optim.nonlinear.scalar.MultivariateOptimizer;
035import org.apache.commons.math4.optim.nonlinear.scalar.GoalType;
036import org.apache.commons.math4.optim.nonlinear.scalar.PopulationSize;
037import org.apache.commons.math4.optim.nonlinear.scalar.Sigma;
038import org.apache.commons.math4.util.MathArrays;
039
040/**
 * Implements the <a href="https://www.sciencedirect.com/science/article/pii/S0045794914000935">
 * Colliding Bodies Optimization</a> (CBO) meta-heuristic.
043 *
044 * <p>
045 * CBO is a global search algorithm that updates a {@link PopulationSize
046 * list of candidate solutions}.
047 * The initial candidates are {@link Sigma randomly spread} around the
048 * start point.
049 * The search gradually becomes more local as a linear function of the
050 * {@link org.apache.commons.math4.optim.MaxIter iterations}.
051 * </p>
052 *
053 * <p>
054 * Class is <em>not</em> thread-safe.
055 * </p>
056 */
057public class CollidingBodiesOptimizer extends MultivariateOptimizer {
058    /** RNG. */
059    private final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
060    /** Comparator. */
061    private final Comparator<PointValuePair> comparator =
062        new Comparator<PointValuePair>() {
063            /** {@inheritDoc} */
064            @Override
065            public int compare(final PointValuePair o1,
066                               final PointValuePair o2) {
067                return Double.compare(o1.getValue(), o2.getValue());
068            }
069        };
070    /** Number of colliding bodies. */
071    private int numberOfBodies;
072    /** Input sigma. */
073    private double[] inputSigma;
074    /** Population of solution candidates. */
075    private List<PointValuePair> population;
076    /** Fitness function (lower is better). */
077    private MultivariateFunction fitness;
078    /** Dimension of solution space. */
079    private int dimension;
080
081    /**
082     * @param rel Relative threshold.
083     * @param abs Absolute threshold.
084     */
085    public CollidingBodiesOptimizer(double rel, double abs) {
086        this(new SimpleValueChecker(rel, abs));
087    }
088
089    /**
090     * @param checker Convergence checker.
091     */
092    public CollidingBodiesOptimizer(ConvergenceChecker<PointValuePair> checker) {
093        super(checker);
094    }
095
096    /**
097     * {@inheritDoc}
098     *
099     * @param optData Optimization data. In addition to those documented in
100     * {@link MultivariateOptimizer#parseOptimizationData(OptimizationData[])
101     * MultivariateOptimizer}, this method will register the following data:
102     * <ul>
103     *  <li>{@link PopulationSize} (if odd, it will be increased by one).</li>
104     *  <li>{@link Sigma}</li>
105     * </ul>
106     * @return {@inheritDoc}
107     */
108    @Override
109    public PointValuePair optimize(OptimizationData... optData) {
110        // Set up base class and perform computation.
111        return super.optimize(optData);
112    }
113
114    /** {@inheritDoc} */
115    @Override
116    protected PointValuePair doOptimize() {
117        dimension = getStartPoint().length;
118        initializeFitnessFunction();
119        initializePopulation();
120
121        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
122
123        final int lastIter = getMaxIterations() - 1;
124        while (getIterations() < lastIter) {
125            updatePopulation();
126
127            if (checker != null) {
128                // Assumes that the population is sorted (cf. "updatePopulation" method).
129                final PointValuePair best = population.get(0);
130                final PointValuePair worst = population.get(numberOfBodies - 1);
131                if (checker.converged(getIterations(), best, worst)) {
132                    return best;
133                }
134            }
135
136            incrementIterationCount();
137        }
138
139        return population.get(0); // Current best.
140    }
141
142    /**
143     * Scans the list of (required and optional) optimization data that
144     * characterize the problem.
145     *
146     * @param optData Optimization data.
147     * The following data will be looked for:
148     * <ul>
149     *  <li>{@link PopulationSize}</li>
150     *  <li>{@link Sigma}</li>
151     * </ul>
152     */
153    @Override
154    protected void parseOptimizationData(OptimizationData... optData) {
155        // Allow base class to register its own data.
156        super.parseOptimizationData(optData);
157
158        // The existing values (as set by the previous call) are reused if
159        // not provided in the argument list.
160        for (OptimizationData data : optData) {
161            if (data instanceof PopulationSize) {
162                numberOfBodies = ((PopulationSize) data).getPopulationSize();
163                if (numberOfBodies % 2 != 0) {
164                    ++numberOfBodies;
165                }
166                continue;
167            }
168            if (data instanceof Sigma) {
169                inputSigma = ((Sigma) data).getSigma();
170                continue;
171            }
172        }
173
174        checkParameters();
175    }
176
177    /**
178     * Checks dimensions and ranges.
179     */
180    private void checkParameters() {
181        final double[] init = getStartPoint();
182        if (inputSigma != null) {
183            if (inputSigma.length != init.length) {
184                throw new DimensionMismatchException(inputSigma.length, init.length);
185            }
186        } else {
187            inputSigma = new double[init.length];
188            Arrays.fill(inputSigma, 1); // Default.
189        }
190
191        if (numberOfBodies <= 0) {
192            throw new NotStrictlyPositiveException(numberOfBodies);
193        }
194    }
195
196    /**
197     * Creates the initial population.
198     */
199    private void initializePopulation() {
200        population = new ArrayList<>(numberOfBodies);
201
202        final double[] init = getStartPoint();
203        for (int i = 0; i < numberOfBodies; i++) {
204            final double[] point = new double[dimension];
205            for (int j = 0; j < dimension; j++) {
206                point[j] = init[j] + inputSigma[j] * (rng.nextDouble() - 0.5);
207            }
208
209            final PointValuePair p = new PointValuePair(point, fitness.value(point));
210            population.add(p);
211        }
212    }
213
214    /**
215     * Creates the fitness function.
216     */
217    private void initializeFitnessFunction() {
218        fitness = new MultivariateFunction() {
219                private final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
220
221                /** {@inheritDoc} */
222                @Override
223                public double value(double[] point) {
224                    // Indirect call to "computeObjectiveValue" in order to
225                    // update the evaluations counter.
226                    final double v = computeObjectiveValue(point);
227                    return isMinim ? v : 1 / v;
228                }
229            };
230    }
231
232    /**
233     * Computes current masses.
234     * It is assumed that all fitnesses have been previously computed.
235     *
236     * @return the masses of the {@link #population}.
237     */
238    private double[] computeMass() {
239        final double[] mass = new double[numberOfBodies];
240
241        double totalMass = 0;
242        for (int i = 0; i < numberOfBodies; i++) {
243            totalMass += 1 / population.get(i).getValue();
244        }
245
246        for (int i = 0; i < numberOfBodies; i++) {
247            mass[i] = 1 / population.get(i).getValue() / totalMass;
248        }
249
250        // Sanity checks.
251        MathArrays.checkPositive(mass);
252        MathArrays.checkNotNaN(mass);
253
254        return mass;
255    }
256
257    /**
258     * Updates the population.
259     */
260    private void updatePopulation() {
261        // New population.
262        final List<PointValuePair> nextGen = new ArrayList<>(numberOfBodies);
263
264        // Sort from best to worst.
265        Collections.sort(population, comparator);
266        // Compute all "masses".
267        final double[] mass = computeMass();
268
269        // Coefficient of restitution.
270        final double cor = 1 - getIterations() / (double) getMaxIterations();
271        if (cor < 0 ||
272            cor > 1) {
273            throw new MathInternalError();
274        }
275
276        // First index of "worse" bodies.
277        final int max = numberOfBodies / 2;
278
279        for (int i = 0; i < max; i++) {
280            // Pairing.
281            final int bI = i; // "better" body.
282            final int wI = i + max; // "worse" body.
283
284            // Current "position" (i.e. candidate solutions).
285            final double[] bP = population.get(bI).getPoint();
286            final double[] wP = population.get(wI).getPoint();
287
288            // Position vectors.
289            final double[] bV = new double[dimension];
290            final double[] wV = new double[dimension];
291
292            // Coefficient for computing final velocities.
293            final double oneOverMassSum = 1 / (mass[bI] + mass[wI]);
294            final double bCoeff = (mass[wI] + cor * mass[wI]) * oneOverMassSum;
295            final double wCoeff = (mass[wI] - cor * mass[bI]) * oneOverMassSum;
296
297            // All updates can be done in-place within the same loop.
298            for (int j = 0; j < dimension; j++) {
299                // "bP" is unchanged (set to zero: At rest).
300
301                // Set "worse" body's position to the difference with
302                // the corresponding "better" body.
303                wV[j] = wP[j] - bP[j];
304
305                // Update velocities.
306                final double wVel = wV[j];
307                bV[j] = bCoeff * wVel;
308                wV[j] = wCoeff * wVel;
309
310                // Update positions.
311                final double rand = 2 * rng.nextDouble() - 1;
312                final double bPos = bP[j];
313                bP[j] = bPos + rand * bV[j];
314                wP[j] = bPos + rand * wV[j];
315            }
316
317            // Add new candidate solutions.
318            nextGen.add(new PointValuePair(wP, fitness.value(wP)));
319            nextGen.add(new PointValuePair(bP, fitness.value(bP)));
320        }
321
322        // Replace with next generation.
323        population = nextGen;
324    }
325}