SimplexOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

  18. import java.util.Arrays;
  19. import java.util.List;
  20. import java.util.ArrayList;
  21. import java.util.Comparator;
  22. import java.util.Collections;
  23. import java.util.Objects;
  24. import java.util.function.UnaryOperator;
  25. import java.util.function.IntSupplier;
  26. import java.util.concurrent.CopyOnWriteArrayList;

  27. import org.apache.commons.math4.legacy.core.MathArrays;
  28. import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
  29. import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
  30. import org.apache.commons.math4.legacy.exception.MathInternalError;
  31. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  32. import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
  33. import org.apache.commons.math4.legacy.optim.OptimizationData;
  34. import org.apache.commons.math4.legacy.optim.PointValuePair;
  35. import org.apache.commons.math4.legacy.optim.SimpleValueChecker;
  36. import org.apache.commons.math4.legacy.optim.InitialGuess;
  37. import org.apache.commons.math4.legacy.optim.MaxEval;
  38. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
  39. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.MultivariateOptimizer;
  40. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing;
  41. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.PopulationSize;
  42. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction;

  43. /**
  44.  * This class implements simplex-based direct search optimization.
  45.  *
  46.  * <p>
  47.  * Direct search methods only use objective function values, they do
  48.  * not need derivatives and don't either try to compute approximation
  49.  * of the derivatives. According to a 1996 paper by Margaret H. Wright
  50.  * (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
  51.  * Search Methods: Once Scorned, Now Respectable</a>), they are used
  52.  * when either the computation of the derivative is impossible (noisy
  53.  * functions, unpredictable discontinuities) or difficult (complexity,
  54.  * computation cost). In the first cases, rather than an optimum, a
  55.  * <em>not too bad</em> point is desired. In the latter cases, an
  56.  * optimum is desired but cannot be reasonably found. In all cases
  57.  * direct search methods can be useful.
  58.  *
  59.  * <p>
  60.  * Simplex-based direct search methods are based on comparison of
  61.  * the objective function values at the vertices of a simplex (which is a
  62.  * set of n+1 points in dimension n) that is updated by the algorithms
  63.  * steps.
  64.  *
  65.  * <p>
  66.  * In addition to those documented in
  67.  * {@link MultivariateOptimizer#optimize(OptimizationData[]) MultivariateOptimizer},
  68.  * an instance of this class will register the following data:
  69.  * <ul>
  70.  *  <li>{@link Simplex}</li>
  71.  *  <li>{@link Simplex.TransformFactory}</li>
  72.  *  <li>{@link SimulatedAnnealing}</li>
  73.  *  <li>{@link PopulationSize}</li>
  74.  * </ul>
  75.  *
  76.  * <p>
  77.  * Each call to {@code optimize} will re-use the start configuration of
  78.  * the current simplex and move it such that its first vertex is at the
  79.  * provided start point of the optimization.
  80.  * If the {@code optimize} method is called to solve a different problem
  81.  * and the number of parameters change, the simplex must be re-initialized
  82.  * to one with the appropriate dimensions.
  83.  *
  84.  * <p>
  85.  * Convergence is considered achieved when <em>all</em> the simplex points
  86.  * have converged.
  87.  * <p>
  88.  * Whenever {@link SimulatedAnnealing simulated annealing (SA)} is activated,
  89.  * and the SA phase has completed, convergence has probably not been reached
  90.  * yet; whenever it's the case, an additional (non-SA) search will be performed
  91.  * (using the current best simplex point as a start point).
  92.  * <p>
  93.  * Additional "best list" searches can be requested through setting the
  94.  * {@link PopulationSize} argument of the {@link #optimize(OptimizationData[])
  95.  * optimize} method.
  96.  *
  97.  * <p>
  98.  * This implementation does not directly support constrained optimization
  99.  * with simple bounds.
  100.  * The call to {@link #optimize(OptimizationData[]) optimize} will throw
  101.  * {@link MathUnsupportedOperationException} if bounds are passed to it.
  102.  *
  103.  * @see NelderMeadTransform
  104.  * @see MultiDirectionalTransform
  105.  * @see HedarFukushimaTransform
  106.  */
  107. public class SimplexOptimizer extends MultivariateOptimizer {
  108.     /** Default simplex side length ratio. */
  109.     private static final double SIMPLEX_SIDE_RATIO = 1e-1;
  110.     /** Simplex update function factory. */
  111.     private Simplex.TransformFactory updateRule;
  112.     /** Initial simplex. */
  113.     private Simplex initialSimplex;
  114.     /** Simulated annealing setup (optional). */
  115.     private SimulatedAnnealing simulatedAnnealing = null;
  116.     /** User-defined number of additional optimizations (optional). */
  117.     private int populationSize = 0;
  118.     /** Actual number of additional optimizations. */
  119.     private int additionalSearch = 0;
  120.     /** Callbacks. */
  121.     private final List<Observer> callbacks = new CopyOnWriteArrayList<>();

  122.     /**
  123.      * @param checker Convergence checker.
  124.      */
  125.     public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) {
  126.         super(checker);
  127.     }

  128.     /**
  129.      * @param rel Relative threshold.
  130.      * @param abs Absolute threshold.
  131.      */
  132.     public SimplexOptimizer(double rel,
  133.                             double abs) {
  134.         this(new SimpleValueChecker(rel, abs));
  135.     }

  136.     /**
  137.      * Callback interface for updating caller's code with the current
  138.      * state of the optimization.
  139.      */
  140.     @FunctionalInterface
  141.     public interface Observer {
  142.         /**
  143.          * Method called after each modification of the {@code simplex}.
  144.          *
  145.          * @param simplex Current simplex.
  146.          * @param isInit {@code true} at the start of a new search (either
  147.          * "main" or "best list"), after the initial simplex's vertices
  148.          * have been evaluated.
  149.          * @param numEval Number of evaluations of the objective function.
  150.          */
  151.         void update(Simplex simplex,
  152.                     boolean isInit,
  153.                     int numEval);
  154.     }

  155.     /**
  156.      * Register a callback.
  157.      *
  158.      * @param cb Callback.
  159.      */
  160.     public void addObserver(Observer cb) {
  161.         Objects.requireNonNull(cb, "Callback");
  162.         callbacks.add(cb);
  163.     }

  164.     /** {@inheritDoc} */
  165.     @Override
  166.     protected PointValuePair doOptimize() {
  167.         checkParameters();

  168.         final MultivariateFunction evalFunc = getObjectiveFunction();

  169.         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
  170.         final Comparator<PointValuePair> comparator = (o1, o2) -> {
  171.             final double v1 = o1.getValue();
  172.             final double v2 = o2.getValue();
  173.             return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1);
  174.         };

  175.         // Start points for additional search.
  176.         final List<PointValuePair> bestList = new ArrayList<>();

  177.         Simplex currentSimplex = initialSimplex.translate(getStartPoint()).evaluate(evalFunc, comparator);
  178.         notifyObservers(currentSimplex, true);
  179.         double temperature = Double.NaN; // Only used with simulated annealing.
  180.         Simplex previousSimplex = null;

  181.         if (simulatedAnnealing != null) {
  182.             temperature =
  183.                 temperature(currentSimplex.get(0),
  184.                             currentSimplex.get(currentSimplex.getDimension()),
  185.                             simulatedAnnealing.getStartProbability());
  186.         }

  187.         while (true) {
  188.             if (previousSimplex != null) { // Skip check at first iteration.
  189.                 if (hasConverged(previousSimplex, currentSimplex)) {
  190.                     return currentSimplex.get(0);
  191.                 }
  192.             }

  193.             // We still need to search.
  194.             previousSimplex = currentSimplex;

  195.             if (simulatedAnnealing != null) {
  196.                 // Update current temperature.
  197.                 temperature =
  198.                     simulatedAnnealing.getCoolingSchedule().apply(temperature,
  199.                                                                   currentSimplex);

  200.                 final double endTemperature =
  201.                     temperature(currentSimplex.get(0),
  202.                                 currentSimplex.get(currentSimplex.getDimension()),
  203.                                 simulatedAnnealing.getEndProbability());

  204.                 if (temperature < endTemperature) {
  205.                     break;
  206.                 }

  207.                 final UnaryOperator<Simplex> update =
  208.                     updateRule.create(evalFunc,
  209.                                       comparator,
  210.                                       simulatedAnnealing.metropolis(temperature));

  211.                 for (int i = 0; i < simulatedAnnealing.getEpochDuration(); i++) {
  212.                     // Simplex is transformed (and observers are notified).
  213.                     currentSimplex = applyUpdate(update,
  214.                                                  currentSimplex,
  215.                                                  evalFunc,
  216.                                                  comparator);
  217.                 }
  218.             } else {
  219.                 // No simulated annealing.
  220.                 final UnaryOperator<Simplex> update =
  221.                     updateRule.create(evalFunc, comparator, null);

  222.                 // Simplex is transformed (and observers are notified).
  223.                 currentSimplex = applyUpdate(update,
  224.                                              currentSimplex,
  225.                                              evalFunc,
  226.                                              comparator);
  227.             }

  228.             if (additionalSearch != 0) {
  229.                 // In "bestList", we must keep track of at least two points
  230.                 // in order to be able to compute the new initial simplex for
  231.                 // the additional search.
  232.                 final int max = Math.max(additionalSearch, 2);

  233.                 // Store best points.
  234.                 for (int i = 0; i < currentSimplex.getSize(); i++) {
  235.                     keepIfBetter(currentSimplex.get(i),
  236.                                  comparator,
  237.                                  bestList,
  238.                                  max);
  239.                 }
  240.             }

  241.             incrementIterationCount();
  242.         }

  243.         // No convergence.

  244.         if (additionalSearch > 0) {
  245.             // Additional optimizations.
  246.             // Reference to counter in the "main" search in order to retrieve
  247.             // the total number of evaluations in the "best list" search.
  248.             final IntSupplier evalCount = () -> getEvaluations();

  249.             return bestListSearch(evalFunc,
  250.                                   comparator,
  251.                                   bestList,
  252.                                   evalCount);
  253.         }

  254.         throw new MathInternalError(); // Should never happen.
  255.     }

  256.     /**
  257.      * Scans the list of (required and optional) optimization data that
  258.      * characterize the problem.
  259.      *
  260.      * @param optData Optimization data.
  261.      * The following data will be looked for:
  262.      * <ul>
  263.      *  <li>{@link Simplex}</li>
  264.      *  <li>{@link Simplex.TransformFactory}</li>
  265.      *  <li>{@link SimulatedAnnealing}</li>
  266.      *  <li>{@link PopulationSize}</li>
  267.      * </ul>
  268.      */
  269.     @Override
  270.     protected void parseOptimizationData(OptimizationData... optData) {
  271.         // Allow base class to register its own data.
  272.         super.parseOptimizationData(optData);

  273.         // The existing values (as set by the previous call) are reused
  274.         // if not provided in the argument list.
  275.         for (OptimizationData data : optData) {
  276.             if (data instanceof Simplex) {
  277.                 initialSimplex = (Simplex) data;
  278.             } else if (data instanceof Simplex.TransformFactory) {
  279.                 updateRule = (Simplex.TransformFactory) data;
  280.             } else if (data instanceof SimulatedAnnealing) {
  281.                 simulatedAnnealing = (SimulatedAnnealing) data;
  282.             } else if (data instanceof PopulationSize) {
  283.                 populationSize = ((PopulationSize) data).getPopulationSize();
  284.             }
  285.         }
  286.     }

  287.     /**
  288.      * Detects whether the simplex has shrunk below the user-defined
  289.      * tolerance.
  290.      *
  291.      * @param previous Simplex at previous iteration.
  292.      * @param current Simplex at current iteration.
  293.      * @return {@code true} if convergence is considered achieved.
  294.      */
  295.     private boolean hasConverged(Simplex previous,
  296.                                  Simplex current) {
  297.         final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();

  298.         for (int i = 0; i < current.getSize(); i++) {
  299.             final PointValuePair prev = previous.get(i);
  300.             final PointValuePair curr = current.get(i);

  301.             if (!checker.converged(getIterations(), prev, curr)) {
  302.                 return false;
  303.             }
  304.         }

  305.         return true;
  306.     }

  307.     /**
  308.      * @throws MathUnsupportedOperationException if bounds were passed to the
  309.      * {@link #optimize(OptimizationData[]) optimize} method.
  310.      * @throws NullPointerException if no initial simplex or no transform rule
  311.      * was passed to the {@link #optimize(OptimizationData[]) optimize} method.
  312.      * @throws IllegalArgumentException if {@link #populationSize} is negative.
  313.      */
  314.     private void checkParameters() {
  315.         Objects.requireNonNull(updateRule, "Update rule");
  316.         Objects.requireNonNull(initialSimplex, "Initial simplex");

  317.         if (getLowerBound() != null ||
  318.             getUpperBound() != null) {
  319.             throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
  320.         }

  321.         if (populationSize < 0) {
  322.             throw new IllegalArgumentException("Population size");
  323.         }

  324.         additionalSearch = simulatedAnnealing == null ?
  325.             Math.max(0, populationSize) :
  326.             Math.max(1, populationSize);
  327.     }

  328.     /**
  329.      * Computes the temperature as a function of the acceptance probability
  330.      * and the fitness difference between two of the simplex vertices (usually
  331.      * the best and worst points).
  332.      *
  333.      * @param p1 Simplex point.
  334.      * @param p2 Simplex point.
  335.      * @param prob Acceptance probability.
  336.      * @return the temperature.
  337.      */
  338.     private double temperature(PointValuePair p1,
  339.                                PointValuePair p2,
  340.                                double prob) {
  341.         return -Math.abs(p1.getValue() - p2.getValue()) / Math.log(prob);
  342.     }

  343.     /**
  344.      * Stores the given {@code candidate} if its fitness is better than
  345.      * that of the last (assumed to be the worst) point in {@code list}.
  346.      *
  347.      * <p>If the list is below the maximum size then the {@code candidate}
  348.      * is added if it is not already in the list. The list is sorted
  349.      * when it reaches the maximum size.
  350.      *
  351.      * @param candidate Point to be stored.
  352.      * @param comp Fitness comparator.
  353.      * @param list Starting points (modified in-place).
  354.      * @param max Maximum size of the {@code list}.
  355.      */
  356.     private static void keepIfBetter(PointValuePair candidate,
  357.                                      Comparator<PointValuePair> comp,
  358.                                      List<PointValuePair> list,
  359.                                      int max) {
  360.         final int listSize = list.size();
  361.         final double[] candidatePoint = candidate.getPoint();
  362.         if (listSize == 0) {
  363.             list.add(candidate);
  364.         } else if (listSize < max) {
  365.             // List is not fully populated yet.
  366.             for (PointValuePair p : list) {
  367.                 final double[] pPoint = p.getPoint();
  368.                 if (Arrays.equals(pPoint, candidatePoint)) {
  369.                     // Point was already stored.
  370.                     return;
  371.                 }
  372.             }
  373.             // Store candidate.
  374.             list.add(candidate);
  375.             // Sort the list when required
  376.             if (list.size() == max) {
  377.                 Collections.sort(list, comp);
  378.             }
  379.         } else {
  380.             final int last = max - 1;
  381.             if (comp.compare(candidate, list.get(last)) < 0) {
  382.                 for (PointValuePair p : list) {
  383.                     final double[] pPoint = p.getPoint();
  384.                     if (Arrays.equals(pPoint, candidatePoint)) {
  385.                         // Point was already stored.
  386.                         return;
  387.                     }
  388.                 }

  389.                 // Store better candidate and reorder the list.
  390.                 list.set(last, candidate);
  391.                 Collections.sort(list, comp);
  392.             }
  393.         }
  394.     }

  395.     /**
  396.      * Computes the smallest distance between the given {@code point}
  397.      * and any of the other points in the {@code list}.
  398.      *
  399.      * @param point Point.
  400.      * @param list List.
  401.      * @return the smallest distance.
  402.      */
  403.     private static double shortestDistance(PointValuePair point,
  404.                                            List<PointValuePair> list) {
  405.         double minDist = Double.POSITIVE_INFINITY;

  406.         final double[] p = point.getPoint();
  407.         for (PointValuePair other : list) {
  408.             final double[] pOther = other.getPoint();
  409.             if (!Arrays.equals(p, pOther)) {
  410.                 final double dist = MathArrays.distance(p, pOther);
  411.                 if (dist < minDist) {
  412.                     minDist = dist;
  413.                 }
  414.             }
  415.         }

  416.         return minDist;
  417.     }

  418.     /**
  419.      * Perform additional optimizations.
  420.      *
  421.      * @param evalFunc Objective function.
  422.      * @param comp Fitness comparator.
  423.      * @param bestList Best points encountered during the "main" search.
  424.      * List is assumed to be ordered from best to worst.
  425.      * @param evalCount Evaluation counter.
  426.      * @return the optimum.
  427.      */
  428.     private PointValuePair bestListSearch(MultivariateFunction evalFunc,
  429.                                           Comparator<PointValuePair> comp,
  430.                                           List<PointValuePair> bestList,
  431.                                           IntSupplier evalCount) {
  432.         PointValuePair best = bestList.get(0); // Overall best result.

  433.         // Additional local optimizations using each of the best
  434.         // points visited during the main search.
  435.         for (int i = 0; i < additionalSearch; i++) {
  436.             final PointValuePair start = bestList.get(i);
  437.             // Find shortest distance to the other points.
  438.             final double dist = shortestDistance(start, bestList);
  439.             final double[] init = start.getPoint();
  440.             // Create smaller initial simplex.
  441.             final Simplex simplex = Simplex.equalSidesAlongAxes(init.length,
  442.                                                                 SIMPLEX_SIDE_RATIO * dist);

  443.             final PointValuePair r = directSearch(init,
  444.                                                   simplex,
  445.                                                   evalFunc,
  446.                                                   getConvergenceChecker(),
  447.                                                   getGoalType(),
  448.                                                   callbacks,
  449.                                                   evalCount);
  450.             if (comp.compare(r, best) < 0) {
  451.                 best = r; // New overall best.
  452.             }
  453.         }

  454.         return best;
  455.     }

  456.     /**
  457.      * @param init Start point.
  458.      * @param simplex Initial simplex.
  459.      * @param eval Objective function.
  460.      * Note: It is assumed that evaluations of this function are
  461.      * incrementing the main counter.
  462.      * @param checker Convergence checker.
  463.      * @param goalType Whether to minimize or maximize the objective function.
  464.      * @param cbList Callbacks.
  465.      * @param evalCount Evaluation counter.
  466.      * @return the optimum.
  467.      */
  468.     private static PointValuePair directSearch(double[] init,
  469.                                                Simplex simplex,
  470.                                                MultivariateFunction eval,
  471.                                                ConvergenceChecker<PointValuePair> checker,
  472.                                                GoalType goalType,
  473.                                                List<Observer> cbList,
  474.                                                final IntSupplier evalCount) {
  475.         final SimplexOptimizer optim = new SimplexOptimizer(checker);

  476.         for (Observer cOrig : cbList) {
  477.             final SimplexOptimizer.Observer cNew = (spx, isInit, numEval) ->
  478.                 cOrig.update(spx, isInit, evalCount.getAsInt());

  479.             optim.addObserver(cNew);
  480.         }

  481.         return optim.optimize(MaxEval.unlimited(),
  482.                               new ObjectiveFunction(eval),
  483.                               goalType,
  484.                               new InitialGuess(init),
  485.                               simplex,
  486.                               new MultiDirectionalTransform());
  487.     }

  488.     /**
  489.      * @param simplex Current simplex.
  490.      * @param isInit Set to {@code true} at the start of a new search
  491.      * (either "main" or "best list"), after the evaluation of the initial
  492.      * simplex's vertices.
  493.      */
  494.     private void notifyObservers(Simplex simplex,
  495.                                  boolean isInit) {
  496.         for (Observer cb : callbacks) {
  497.             cb.update(simplex,
  498.                       isInit,
  499.                       getEvaluations());
  500.         }
  501.     }

  502.     /**
  503.      * Applies the {@code update} to the given {@code simplex} (and notifies
  504.      * observers).
  505.      *
  506.      * @param update Simplex transformation.
  507.      * @param simplex Current simplex.
  508.      * @param eval Objective function.
  509.      * @param comp Fitness comparator.
  510.      * @return the transformed simplex.
  511.      */
  512.     private Simplex applyUpdate(UnaryOperator<Simplex> update,
  513.                                 Simplex simplex,
  514.                                 MultivariateFunction eval,
  515.                                 Comparator<PointValuePair> comp) {
  516.         final Simplex transformed = update.apply(simplex).evaluate(eval, comp);

  517.         notifyObservers(transformed, false);

  518.         return transformed;
  519.     }
  520. }