001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
018
019import java.util.Comparator;
020import java.util.function.UnaryOperator;
021import java.util.function.DoublePredicate;
022
023import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
024import org.apache.commons.math4.legacy.optim.PointValuePair;
025
026/**
027 * <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>.
028 */
029public class NelderMeadTransform
030    implements Simplex.TransformFactory {
031    /** Default value for {@link #alpha}: {@value}. */
032    private static final double DEFAULT_ALPHA = 1;
033    /** Default value for {@link #gamma}: {@value}. */
034    private static final double DEFAULT_GAMMA = 2;
035    /** Default value for {@link #rho}: {@value}. */
036    private static final double DEFAULT_RHO = 0.5;
037    /** Default value for {@link #sigma}: {@value}. */
038    private static final double DEFAULT_SIGMA = 0.5;
039    /** Reflection coefficient. */
040    private final double alpha;
041    /** Expansion coefficient. */
042    private final double gamma;
043    /** Contraction coefficient. */
044    private final double rho;
045    /** Shrinkage coefficient. */
046    private final double sigma;
047
048    /**
049     * @param alpha Reflection coefficient.
050     * @param gamma Expansion coefficient.
051     * @param rho Contraction coefficient.
052     * @param sigma Shrinkage coefficient.
053     */
054    public NelderMeadTransform(double alpha,
055                               double gamma,
056                               double rho,
057                               double sigma) {
058        this.alpha = alpha;
059        this.gamma = gamma;
060        this.rho = rho;
061        this.sigma = sigma;
062    }
063
064    /**
065     * Transform with default values.
066     */
067    public NelderMeadTransform() {
068        this(DEFAULT_ALPHA,
069             DEFAULT_GAMMA,
070             DEFAULT_RHO,
071             DEFAULT_SIGMA);
072    }
073
074    /** {@inheritDoc} */
075    @Override
076    public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction,
077                                         final Comparator<PointValuePair> comparator,
078                                         final DoublePredicate sa) {
079        return original -> {
080            // The simplex has n + 1 points if dimension is n.
081            final int n = original.getDimension();
082
083            // Interesting values.
084            final PointValuePair best = original.get(0);
085            final PointValuePair secondWorst = original.get(n - 1);
086            final PointValuePair worst = original.get(n);
087            final double[] xWorst = worst.getPoint();
088
089            // Centroid of the best vertices, dismissing the worst point (at index n).
090            final double[] centroid = Simplex.centroid(original.asList().subList(0, n));
091
092            // Reflection.
093            final PointValuePair reflected = Simplex.newPoint(centroid,
094                                                              -alpha,
095                                                              xWorst,
096                                                              evaluationFunction);
097            if (comparator.compare(reflected, secondWorst) < 0 &&
098                comparator.compare(best, reflected) <= 0) {
099                return original.replaceLast(reflected);
100            }
101
102            if (comparator.compare(reflected, best) < 0) {
103                // Expansion.
104                final PointValuePair expanded = Simplex.newPoint(centroid,
105                                                                 -gamma,
106                                                                 xWorst,
107                                                                 evaluationFunction);
108                if (comparator.compare(expanded, reflected) < 0 ||
109                    (sa != null &&
110                     sa.test(expanded.getValue() - reflected.getValue()))) {
111                    return original.replaceLast(expanded);
112                } else {
113                    return original.replaceLast(reflected);
114                }
115            }
116
117            if (comparator.compare(reflected, worst) < 0) {
118                // Outside contraction.
119                final PointValuePair contracted = Simplex.newPoint(centroid,
120                                                                   rho,
121                                                                   reflected.getPoint(),
122                                                                   evaluationFunction);
123                if (comparator.compare(contracted, reflected) < 0) {
124                    return original.replaceLast(contracted); // Accept contracted point.
125                }
126            } else {
127                // Inside contraction.
128                final PointValuePair contracted = Simplex.newPoint(centroid,
129                                                                   rho,
130                                                                   xWorst,
131                                                                   evaluationFunction);
132                if (comparator.compare(contracted, worst) < 0) {
133                    return original.replaceLast(contracted); // Accept contracted point.
134                }
135            }
136
137            // Shrink.
138            return original.shrink(sigma, evaluationFunction);
139        };
140    }
141
142    /** {@inheritDoc} */
143    @Override
144    public String toString() {
145        return "Nelder-Mead [a=" + alpha +
146            " g=" + gamma +
147            " r=" + rho +
148            " s=" + sigma + "]";
149    }
150}