NelderMeadTransform.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  17. package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

  18. import java.util.Comparator;
  19. import java.util.function.UnaryOperator;
  20. import java.util.function.DoublePredicate;

  21. import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
  22. import org.apache.commons.math4.legacy.optim.PointValuePair;

  23. /**
  24.  * <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>.
  25.  */
  26. public class NelderMeadTransform
  27.     implements Simplex.TransformFactory {
  28.     /** Default value for {@link #alpha}: {@value}. */
  29.     private static final double DEFAULT_ALPHA = 1;
  30.     /** Default value for {@link #gamma}: {@value}. */
  31.     private static final double DEFAULT_GAMMA = 2;
  32.     /** Default value for {@link #rho}: {@value}. */
  33.     private static final double DEFAULT_RHO = 0.5;
  34.     /** Default value for {@link #sigma}: {@value}. */
  35.     private static final double DEFAULT_SIGMA = 0.5;
  36.     /** Reflection coefficient. */
  37.     private final double alpha;
  38.     /** Expansion coefficient. */
  39.     private final double gamma;
  40.     /** Contraction coefficient. */
  41.     private final double rho;
  42.     /** Shrinkage coefficient. */
  43.     private final double sigma;

  44.     /**
  45.      * @param alpha Reflection coefficient.
  46.      * @param gamma Expansion coefficient.
  47.      * @param rho Contraction coefficient.
  48.      * @param sigma Shrinkage coefficient.
  49.      */
  50.     public NelderMeadTransform(double alpha,
  51.                                double gamma,
  52.                                double rho,
  53.                                double sigma) {
  54.         this.alpha = alpha;
  55.         this.gamma = gamma;
  56.         this.rho = rho;
  57.         this.sigma = sigma;
  58.     }

  59.     /**
  60.      * Transform with default values.
  61.      */
  62.     public NelderMeadTransform() {
  63.         this(DEFAULT_ALPHA,
  64.              DEFAULT_GAMMA,
  65.              DEFAULT_RHO,
  66.              DEFAULT_SIGMA);
  67.     }

  68.     /** {@inheritDoc} */
  69.     @Override
  70.     public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction,
  71.                                          final Comparator<PointValuePair> comparator,
  72.                                          final DoublePredicate sa) {
  73.         return original -> {
  74.             // The simplex has n + 1 points if dimension is n.
  75.             final int n = original.getDimension();

  76.             // Interesting values.
  77.             final PointValuePair best = original.get(0);
  78.             final PointValuePair secondWorst = original.get(n - 1);
  79.             final PointValuePair worst = original.get(n);
  80.             final double[] xWorst = worst.getPoint();

  81.             // Centroid of the best vertices, dismissing the worst point (at index n).
  82.             final double[] centroid = Simplex.centroid(original.asList().subList(0, n));

  83.             // Reflection.
  84.             final PointValuePair reflected = Simplex.newPoint(centroid,
  85.                                                               -alpha,
  86.                                                               xWorst,
  87.                                                               evaluationFunction);
  88.             if (comparator.compare(reflected, secondWorst) < 0 &&
  89.                 comparator.compare(best, reflected) <= 0) {
  90.                 return original.replaceLast(reflected);
  91.             }

  92.             if (comparator.compare(reflected, best) < 0) {
  93.                 // Expansion.
  94.                 final PointValuePair expanded = Simplex.newPoint(centroid,
  95.                                                                  -gamma,
  96.                                                                  xWorst,
  97.                                                                  evaluationFunction);
  98.                 if (comparator.compare(expanded, reflected) < 0 ||
  99.                     (sa != null &&
  100.                      sa.test(expanded.getValue() - reflected.getValue()))) {
  101.                     return original.replaceLast(expanded);
  102.                 } else {
  103.                     return original.replaceLast(reflected);
  104.                 }
  105.             }

  106.             if (comparator.compare(reflected, worst) < 0) {
  107.                 // Outside contraction.
  108.                 final PointValuePair contracted = Simplex.newPoint(centroid,
  109.                                                                    rho,
  110.                                                                    reflected.getPoint(),
  111.                                                                    evaluationFunction);
  112.                 if (comparator.compare(contracted, reflected) < 0) {
  113.                     return original.replaceLast(contracted); // Accept contracted point.
  114.                 }
  115.             } else {
  116.                 // Inside contraction.
  117.                 final PointValuePair contracted = Simplex.newPoint(centroid,
  118.                                                                    rho,
  119.                                                                    xWorst,
  120.                                                                    evaluationFunction);
  121.                 if (comparator.compare(contracted, worst) < 0) {
  122.                     return original.replaceLast(contracted); // Accept contracted point.
  123.                 }
  124.             }

  125.             // Shrink.
  126.             return original.shrink(sigma, evaluationFunction);
  127.         };
  128.     }

  129.     /** {@inheritDoc} */
  130.     @Override
  131.     public String toString() {
  132.         return "Nelder-Mead [a=" + alpha +
  133.             " g=" + gamma +
  134.             " r=" + rho +
  135.             " s=" + sigma + "]";
  136.     }
  137. }