/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
18  
19  import java.util.Comparator;
20  
21  import org.apache.commons.math3.optim.PointValuePair;
22  import org.apache.commons.math3.analysis.MultivariateFunction;
23  
24  /**
25   * This class implements the Nelder-Mead simplex algorithm.
26   *
27   * @since 3.0
28   */
29  public class NelderMeadSimplex extends AbstractSimplex {
30      /** Default value for {@link #rho}: {@value}. */
31      private static final double DEFAULT_RHO = 1;
32      /** Default value for {@link #khi}: {@value}. */
33      private static final double DEFAULT_KHI = 2;
34      /** Default value for {@link #gamma}: {@value}. */
35      private static final double DEFAULT_GAMMA = 0.5;
36      /** Default value for {@link #sigma}: {@value}. */
37      private static final double DEFAULT_SIGMA = 0.5;
38      /** Reflection coefficient. */
39      private final double rho;
40      /** Expansion coefficient. */
41      private final double khi;
42      /** Contraction coefficient. */
43      private final double gamma;
44      /** Shrinkage coefficient. */
45      private final double sigma;
46  
47      /**
48       * Build a Nelder-Mead simplex with default coefficients.
49       * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
50       * for both gamma and sigma.
51       *
52       * @param n Dimension of the simplex.
53       */
54      public NelderMeadSimplex(final int n) {
55          this(n, 1d);
56      }
57  
58      /**
59       * Build a Nelder-Mead simplex with default coefficients.
60       * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
61       * for both gamma and sigma.
62       *
63       * @param n Dimension of the simplex.
64       * @param sideLength Length of the sides of the default (hypercube)
65       * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
66       */
67      public NelderMeadSimplex(final int n, double sideLength) {
68          this(n, sideLength,
69               DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
70      }
71  
72      /**
73       * Build a Nelder-Mead simplex with specified coefficients.
74       *
75       * @param n Dimension of the simplex. See
76       * {@link AbstractSimplex#AbstractSimplex(int,double)}.
77       * @param sideLength Length of the sides of the default (hypercube)
78       * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
79       * @param rho Reflection coefficient.
80       * @param khi Expansion coefficient.
81       * @param gamma Contraction coefficient.
82       * @param sigma Shrinkage coefficient.
83       */
84      public NelderMeadSimplex(final int n, double sideLength,
85                               final double rho, final double khi,
86                               final double gamma, final double sigma) {
87          super(n, sideLength);
88  
89          this.rho = rho;
90          this.khi = khi;
91          this.gamma = gamma;
92          this.sigma = sigma;
93      }
94  
95      /**
96       * Build a Nelder-Mead simplex with specified coefficients.
97       *
98       * @param n Dimension of the simplex. See
99       * {@link AbstractSimplex#AbstractSimplex(int)}.
100      * @param rho Reflection coefficient.
101      * @param khi Expansion coefficient.
102      * @param gamma Contraction coefficient.
103      * @param sigma Shrinkage coefficient.
104      */
105     public NelderMeadSimplex(final int n,
106                              final double rho, final double khi,
107                              final double gamma, final double sigma) {
108         this(n, 1d, rho, khi, gamma, sigma);
109     }
110 
111     /**
112      * Build a Nelder-Mead simplex with default coefficients.
113      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
114      * for both gamma and sigma.
115      *
116      * @param steps Steps along the canonical axes representing box edges.
117      * They may be negative but not zero. See
118      */
119     public NelderMeadSimplex(final double[] steps) {
120         this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
121     }
122 
123     /**
124      * Build a Nelder-Mead simplex with specified coefficients.
125      *
126      * @param steps Steps along the canonical axes representing box edges.
127      * They may be negative but not zero. See
128      * {@link AbstractSimplex#AbstractSimplex(double[])}.
129      * @param rho Reflection coefficient.
130      * @param khi Expansion coefficient.
131      * @param gamma Contraction coefficient.
132      * @param sigma Shrinkage coefficient.
133      * @throws IllegalArgumentException if one of the steps is zero.
134      */
135     public NelderMeadSimplex(final double[] steps,
136                              final double rho, final double khi,
137                              final double gamma, final double sigma) {
138         super(steps);
139 
140         this.rho = rho;
141         this.khi = khi;
142         this.gamma = gamma;
143         this.sigma = sigma;
144     }
145 
146     /**
147      * Build a Nelder-Mead simplex with default coefficients.
148      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
149      * for both gamma and sigma.
150      *
151      * @param referenceSimplex Reference simplex. See
152      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
153      */
154     public NelderMeadSimplex(final double[][] referenceSimplex) {
155         this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
156     }
157 
158     /**
159      * Build a Nelder-Mead simplex with specified coefficients.
160      *
161      * @param referenceSimplex Reference simplex. See
162      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
163      * @param rho Reflection coefficient.
164      * @param khi Expansion coefficient.
165      * @param gamma Contraction coefficient.
166      * @param sigma Shrinkage coefficient.
167      * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
168      * if the reference simplex does not contain at least one point.
169      * @throws org.apache.commons.math3.exception.DimensionMismatchException
170      * if there is a dimension mismatch in the reference simplex.
171      */
172     public NelderMeadSimplex(final double[][] referenceSimplex,
173                              final double rho, final double khi,
174                              final double gamma, final double sigma) {
175         super(referenceSimplex);
176 
177         this.rho = rho;
178         this.khi = khi;
179         this.gamma = gamma;
180         this.sigma = sigma;
181     }
182 
183     /** {@inheritDoc} */
184     @Override
185     public void iterate(final MultivariateFunction evaluationFunction,
186                         final Comparator<PointValuePair> comparator) {
187         // The simplex has n + 1 points if dimension is n.
188         final int n = getDimension();
189 
190         // Interesting values.
191         final PointValuePair best = getPoint(0);
192         final PointValuePair secondBest = getPoint(n - 1);
193         final PointValuePair worst = getPoint(n);
194         final double[] xWorst = worst.getPointRef();
195 
196         // Compute the centroid of the best vertices (dismissing the worst
197         // point at index n).
198         final double[] centroid = new double[n];
199         for (int i = 0; i < n; i++) {
200             final double[] x = getPoint(i).getPointRef();
201             for (int j = 0; j < n; j++) {
202                 centroid[j] += x[j];
203             }
204         }
205         final double scaling = 1.0 / n;
206         for (int j = 0; j < n; j++) {
207             centroid[j] *= scaling;
208         }
209 
210         // compute the reflection point
211         final double[] xR = new double[n];
212         for (int j = 0; j < n; j++) {
213             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
214         }
215         final PointValuePair reflected
216             = new PointValuePair(xR, evaluationFunction.value(xR), false);
217 
218         if (comparator.compare(best, reflected) <= 0 &&
219             comparator.compare(reflected, secondBest) < 0) {
220             // Accept the reflected point.
221             replaceWorstPoint(reflected, comparator);
222         } else if (comparator.compare(reflected, best) < 0) {
223             // Compute the expansion point.
224             final double[] xE = new double[n];
225             for (int j = 0; j < n; j++) {
226                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
227             }
228             final PointValuePair expanded
229                 = new PointValuePair(xE, evaluationFunction.value(xE), false);
230 
231             if (comparator.compare(expanded, reflected) < 0) {
232                 // Accept the expansion point.
233                 replaceWorstPoint(expanded, comparator);
234             } else {
235                 // Accept the reflected point.
236                 replaceWorstPoint(reflected, comparator);
237             }
238         } else {
239             if (comparator.compare(reflected, worst) < 0) {
240                 // Perform an outside contraction.
241                 final double[] xC = new double[n];
242                 for (int j = 0; j < n; j++) {
243                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
244                 }
245                 final PointValuePair outContracted
246                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
247                 if (comparator.compare(outContracted, reflected) <= 0) {
248                     // Accept the contraction point.
249                     replaceWorstPoint(outContracted, comparator);
250                     return;
251                 }
252             } else {
253                 // Perform an inside contraction.
254                 final double[] xC = new double[n];
255                 for (int j = 0; j < n; j++) {
256                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
257                 }
258                 final PointValuePair inContracted
259                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
260 
261                 if (comparator.compare(inContracted, worst) < 0) {
262                     // Accept the contraction point.
263                     replaceWorstPoint(inContracted, comparator);
264                     return;
265                 }
266             }
267 
268             // Perform a shrink.
269             final double[] xSmallest = getPoint(0).getPointRef();
270             for (int i = 1; i <= n; i++) {
271                 final double[] x = getPoint(i).getPoint();
272                 for (int j = 0; j < n; j++) {
273                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
274                 }
275                 setPoint(i, new PointValuePair(x, Double.NaN, false));
276             }
277             evaluate(evaluationFunction, comparator);
278         }
279     }
280 }