BaseMultiStartMultivariateOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.optim;

  18. import java.util.function.Supplier;
  19. import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
  20. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  21. import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;

  22. /**
  23.  * Base class multi-start optimizer for a multivariate function.
  24.  * <br>
  25.  * This class wraps an optimizer in order to use it several times in
  26.  * turn with different starting points (trying to avoid being trapped
  27.  * in a local extremum when looking for a global one).
  28.  * <em>It is not a "user" class.</em>
  29.  *
  30.  * @param <PAIR> Type of the point/value pair returned by the optimization
  31.  * algorithm.
  32.  *
  33.  * @since 3.0
  34.  */
  35. public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
  36.     extends BaseMultivariateOptimizer<PAIR> {
  37.     /** Underlying classical optimizer. */
  38.     private final BaseMultivariateOptimizer<PAIR> optimizer;
  39.     /** Number of evaluations already performed for all starts. */
  40.     private int totalEvaluations;
  41.     /** Number of starts to go. */
  42.     private final int starts;
  43.     /** Generator of start points ("multi-start"). */
  44.     private final Supplier<double[]> generator;
  45.     /** Optimization data. */
  46.     private OptimizationData[] optimData;
  47.     /**
  48.      * Location in {@link #optimData} where the updated maximum
  49.      * number of evaluations will be stored.
  50.      */
  51.     private int maxEvalIndex = -1;
  52.     /**
  53.      * Location in {@link #optimData} where the updated start value
  54.      * will be stored.
  55.      */
  56.     private int initialGuessIndex = -1;

  57.     /**
  58.      * Create a multi-start optimizer from a single-start optimizer.
  59.      * <p>
  60.      * Note that if there are bounds constraints (see {@link #getLowerBound()}
  61.      * and {@link #getUpperBound()}), then a simple rejection algorithm is used
  62.      * at each restart. This implies that the random vector generator should have
  63.      * a good probability to generate vectors in the bounded domain, otherwise the
  64.      * rejection algorithm will hit the {@link #getMaxEvaluations()} count without
  65.      * generating a proper restart point. Users must be take great care of the <a
  66.      * href="http://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a>.
  67.      * </p>
  68.      * @param optimizer Single-start optimizer to wrap.
  69.      * @param starts Number of starts to perform. If {@code starts == 1},
  70.      * the {@link #optimize(OptimizationData[]) optimize} will return the
  71.      * same solution as the given {@code optimizer} would return.
  72.      * @param generator Generator to use for restarts.
  73.      * @throws NotStrictlyPositiveException if {@code starts < 1}.
  74.      */
  75.     public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
  76.                                                final int starts,
  77.                                                final Supplier<double[]> generator) {
  78.         super(optimizer.getConvergenceChecker());

  79.         if (starts < 1) {
  80.             throw new NotStrictlyPositiveException(starts);
  81.         }

  82.         this.optimizer = optimizer;
  83.         this.starts = starts;
  84.         this.generator = generator;
  85.     }

  86.     /** {@inheritDoc} */
  87.     @Override
  88.     public int getEvaluations() {
  89.         return totalEvaluations;
  90.     }

  91.     /**
  92.      * Gets all the optima found during the last call to {@code optimize}.
  93.      * The optimizer stores all the optima found during a set of
  94.      * restarts. The {@code optimize} method returns the best point only.
  95.      * This method returns all the points found at the end of each starts,
  96.      * including the best one already returned by the {@code optimize} method.
  97.      * <br>
  98.      * The returned array as one element for each start as specified
  99.      * in the constructor. It is ordered with the results from the
  100.      * runs that did converge first, sorted from best to worst
  101.      * objective value (i.e in ascending order if minimizing and in
  102.      * descending order if maximizing), followed by {@code null} elements
  103.      * corresponding to the runs that did not converge. This means all
  104.      * elements will be {@code null} if the {@code optimize} method did throw
  105.      * an exception.
  106.      * This also means that if the first element is not {@code null}, it is
  107.      * the best point found across all starts.
  108.      * <br>
  109.      * The behaviour is undefined if this method is called before
  110.      * {@code optimize}; it will likely throw {@code NullPointerException}.
  111.      *
  112.      * @return an array containing the optima sorted from best to worst.
  113.      */
  114.     public abstract PAIR[] getOptima();

  115.     /**
  116.      * {@inheritDoc}
  117.      *
  118.      * @throws MathIllegalStateException if {@code optData} does not contain an
  119.      * instance of {@link MaxEval} or {@link InitialGuess}.
  120.      */
  121.     @Override
  122.     public PAIR optimize(OptimizationData... optData) {
  123.         // Store arguments in order to pass them to the internal optimizer.
  124.        optimData = optData;
  125.         // Set up base class and perform computations.
  126.         return super.optimize(optData);
  127.     }

  128.     /** {@inheritDoc} */
  129.     @Override
  130.     protected PAIR doOptimize() {
  131.         // Remove all instances of "MaxEval" and "InitialGuess" from the
  132.         // array that will be passed to the internal optimizer.
  133.         // The former is to enforce smaller numbers of allowed evaluations
  134.         // (according to how many have been used up already), and the latter
  135.         // to impose a different start value for each start.
  136.         for (int i = 0; i < optimData.length; i++) {
  137.             if (optimData[i] instanceof MaxEval) {
  138.                 optimData[i] = null;
  139.                 maxEvalIndex = i;
  140.             }
  141.             if (optimData[i] instanceof InitialGuess) {
  142.                 optimData[i] = null;
  143.                 initialGuessIndex = i;
  144.                 continue;
  145.             }
  146.         }
  147.         if (maxEvalIndex == -1) {
  148.             throw new MathIllegalStateException();
  149.         }
  150.         if (initialGuessIndex == -1) {
  151.             throw new MathIllegalStateException();
  152.         }

  153.         RuntimeException lastException = null;
  154.         totalEvaluations = 0;
  155.         clear();

  156.         final int maxEval = getMaxEvaluations();
  157.         final double[] min = getLowerBound();
  158.         final double[] max = getUpperBound();
  159.         final double[] startPoint = getStartPoint();

  160.         // Multi-start loop.
  161.         for (int i = 0; i < starts; i++) {
  162.             // CHECKSTYLE: stop IllegalCatch
  163.             try {
  164.                 // Decrease number of allowed evaluations.
  165.                 optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
  166.                 // New start value.
  167.                 double[] s = null;
  168.                 if (i == 0) {
  169.                     s = startPoint;
  170.                 } else {
  171.                     int attempts = 0;
  172.                     while (s == null) {
  173.                         if (attempts++ >= getMaxEvaluations()) {
  174.                             throw new TooManyEvaluationsException(getMaxEvaluations());
  175.                         }
  176.                         s = generator.get();
  177.                         for (int k = 0; s != null && k < s.length; ++k) {
  178.                             if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) {
  179.                                 // reject the vector
  180.                                 s = null;
  181.                                 break;
  182.                             }
  183.                         }
  184.                     }
  185.                 }
  186.                 optimData[initialGuessIndex] = new InitialGuess(s);
  187.                 // Optimize.
  188.                 final PAIR result = optimizer.optimize(optimData);
  189.                 store(result);
  190.             } catch (RuntimeException mue) {
  191.                 lastException = mue;
  192.             }
  193.             // CHECKSTYLE: resume IllegalCatch

  194.             totalEvaluations += optimizer.getEvaluations();
  195.         }

  196.         final PAIR[] optima = getOptima();
  197.         if (optima.length == 0) {
  198.             // All runs failed.
  199.             throw lastException; // Cannot be null if starts >= 1.
  200.         }

  201.         // Return the best optimum.
  202.         return optima[0];
  203.     }

  204.     /**
  205.      * Method that will be called in order to store each found optimum.
  206.      *
  207.      * @param optimum Result of an optimization run.
  208.      */
  209.     protected abstract void store(PAIR optimum);
  210.     /**
  211.      * Method that will called in order to clear all stored optima.
  212.      */
  213.     protected abstract void clear();
  214. }