View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math3.optim.nonlinear.scalar.noderiv;
18  
19  import java.util.Comparator;
20  
21  import org.apache.commons.math3.optim.PointValuePair;
22  import org.apache.commons.math3.analysis.MultivariateFunction;
23  
24  /**
25   * This class implements the Nelder-Mead simplex algorithm.
26   *
27   * @version $Id: NelderMeadSimplex.java 1435539 2013-01-19 13:27:24Z tn $
28   * @since 3.0
29   */
30  public class NelderMeadSimplex extends AbstractSimplex {
31      /** Default value for {@link #rho}: {@value}. */
32      private static final double DEFAULT_RHO = 1;
33      /** Default value for {@link #khi}: {@value}. */
34      private static final double DEFAULT_KHI = 2;
35      /** Default value for {@link #gamma}: {@value}. */
36      private static final double DEFAULT_GAMMA = 0.5;
37      /** Default value for {@link #sigma}: {@value}. */
38      private static final double DEFAULT_SIGMA = 0.5;
39      /** Reflection coefficient. */
40      private final double rho;
41      /** Expansion coefficient. */
42      private final double khi;
43      /** Contraction coefficient. */
44      private final double gamma;
45      /** Shrinkage coefficient. */
46      private final double sigma;
47  
48      /**
49       * Build a Nelder-Mead simplex with default coefficients.
50       * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
51       * for both gamma and sigma.
52       *
53       * @param n Dimension of the simplex.
54       */
55      public NelderMeadSimplex(final int n) {
56          this(n, 1d);
57      }
58  
59      /**
60       * Build a Nelder-Mead simplex with default coefficients.
61       * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
62       * for both gamma and sigma.
63       *
64       * @param n Dimension of the simplex.
65       * @param sideLength Length of the sides of the default (hypercube)
66       * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
67       */
68      public NelderMeadSimplex(final int n, double sideLength) {
69          this(n, sideLength,
70               DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
71      }
72  
73      /**
74       * Build a Nelder-Mead simplex with specified coefficients.
75       *
76       * @param n Dimension of the simplex. See
77       * {@link AbstractSimplex#AbstractSimplex(int,double)}.
78       * @param sideLength Length of the sides of the default (hypercube)
79       * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
80       * @param rho Reflection coefficient.
81       * @param khi Expansion coefficient.
82       * @param gamma Contraction coefficient.
83       * @param sigma Shrinkage coefficient.
84       */
85      public NelderMeadSimplex(final int n, double sideLength,
86                               final double rho, final double khi,
87                               final double gamma, final double sigma) {
88          super(n, sideLength);
89  
90          this.rho = rho;
91          this.khi = khi;
92          this.gamma = gamma;
93          this.sigma = sigma;
94      }
95  
96      /**
97       * Build a Nelder-Mead simplex with specified coefficients.
98       *
99       * @param n Dimension of the simplex. See
100      * {@link AbstractSimplex#AbstractSimplex(int)}.
101      * @param rho Reflection coefficient.
102      * @param khi Expansion coefficient.
103      * @param gamma Contraction coefficient.
104      * @param sigma Shrinkage coefficient.
105      */
106     public NelderMeadSimplex(final int n,
107                              final double rho, final double khi,
108                              final double gamma, final double sigma) {
109         this(n, 1d, rho, khi, gamma, sigma);
110     }
111 
112     /**
113      * Build a Nelder-Mead simplex with default coefficients.
114      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
115      * for both gamma and sigma.
116      *
117      * @param steps Steps along the canonical axes representing box edges.
118      * They may be negative but not zero. See
119      */
120     public NelderMeadSimplex(final double[] steps) {
121         this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
122     }
123 
124     /**
125      * Build a Nelder-Mead simplex with specified coefficients.
126      *
127      * @param steps Steps along the canonical axes representing box edges.
128      * They may be negative but not zero. See
129      * {@link AbstractSimplex#AbstractSimplex(double[])}.
130      * @param rho Reflection coefficient.
131      * @param khi Expansion coefficient.
132      * @param gamma Contraction coefficient.
133      * @param sigma Shrinkage coefficient.
134      * @throws IllegalArgumentException if one of the steps is zero.
135      */
136     public NelderMeadSimplex(final double[] steps,
137                              final double rho, final double khi,
138                              final double gamma, final double sigma) {
139         super(steps);
140 
141         this.rho = rho;
142         this.khi = khi;
143         this.gamma = gamma;
144         this.sigma = sigma;
145     }
146 
147     /**
148      * Build a Nelder-Mead simplex with default coefficients.
149      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
150      * for both gamma and sigma.
151      *
152      * @param referenceSimplex Reference simplex. See
153      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
154      */
155     public NelderMeadSimplex(final double[][] referenceSimplex) {
156         this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
157     }
158 
159     /**
160      * Build a Nelder-Mead simplex with specified coefficients.
161      *
162      * @param referenceSimplex Reference simplex. See
163      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
164      * @param rho Reflection coefficient.
165      * @param khi Expansion coefficient.
166      * @param gamma Contraction coefficient.
167      * @param sigma Shrinkage coefficient.
168      * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException
169      * if the reference simplex does not contain at least one point.
170      * @throws org.apache.commons.math3.exception.DimensionMismatchException
171      * if there is a dimension mismatch in the reference simplex.
172      */
173     public NelderMeadSimplex(final double[][] referenceSimplex,
174                              final double rho, final double khi,
175                              final double gamma, final double sigma) {
176         super(referenceSimplex);
177 
178         this.rho = rho;
179         this.khi = khi;
180         this.gamma = gamma;
181         this.sigma = sigma;
182     }
183 
184     /** {@inheritDoc} */
185     @Override
186     public void iterate(final MultivariateFunction evaluationFunction,
187                         final Comparator<PointValuePair> comparator) {
188         // The simplex has n + 1 points if dimension is n.
189         final int n = getDimension();
190 
191         // Interesting values.
192         final PointValuePair best = getPoint(0);
193         final PointValuePair secondBest = getPoint(n - 1);
194         final PointValuePair worst = getPoint(n);
195         final double[] xWorst = worst.getPointRef();
196 
197         // Compute the centroid of the best vertices (dismissing the worst
198         // point at index n).
199         final double[] centroid = new double[n];
200         for (int i = 0; i < n; i++) {
201             final double[] x = getPoint(i).getPointRef();
202             for (int j = 0; j < n; j++) {
203                 centroid[j] += x[j];
204             }
205         }
206         final double scaling = 1.0 / n;
207         for (int j = 0; j < n; j++) {
208             centroid[j] *= scaling;
209         }
210 
211         // compute the reflection point
212         final double[] xR = new double[n];
213         for (int j = 0; j < n; j++) {
214             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
215         }
216         final PointValuePair reflected
217             = new PointValuePair(xR, evaluationFunction.value(xR), false);
218 
219         if (comparator.compare(best, reflected) <= 0 &&
220             comparator.compare(reflected, secondBest) < 0) {
221             // Accept the reflected point.
222             replaceWorstPoint(reflected, comparator);
223         } else if (comparator.compare(reflected, best) < 0) {
224             // Compute the expansion point.
225             final double[] xE = new double[n];
226             for (int j = 0; j < n; j++) {
227                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
228             }
229             final PointValuePair expanded
230                 = new PointValuePair(xE, evaluationFunction.value(xE), false);
231 
232             if (comparator.compare(expanded, reflected) < 0) {
233                 // Accept the expansion point.
234                 replaceWorstPoint(expanded, comparator);
235             } else {
236                 // Accept the reflected point.
237                 replaceWorstPoint(reflected, comparator);
238             }
239         } else {
240             if (comparator.compare(reflected, worst) < 0) {
241                 // Perform an outside contraction.
242                 final double[] xC = new double[n];
243                 for (int j = 0; j < n; j++) {
244                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
245                 }
246                 final PointValuePair outContracted
247                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
248                 if (comparator.compare(outContracted, reflected) <= 0) {
249                     // Accept the contraction point.
250                     replaceWorstPoint(outContracted, comparator);
251                     return;
252                 }
253             } else {
254                 // Perform an inside contraction.
255                 final double[] xC = new double[n];
256                 for (int j = 0; j < n; j++) {
257                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
258                 }
259                 final PointValuePair inContracted
260                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
261 
262                 if (comparator.compare(inContracted, worst) < 0) {
263                     // Accept the contraction point.
264                     replaceWorstPoint(inContracted, comparator);
265                     return;
266                 }
267             }
268 
269             // Perform a shrink.
270             final double[] xSmallest = getPoint(0).getPointRef();
271             for (int i = 1; i <= n; i++) {
272                 final double[] x = getPoint(i).getPoint();
273                 for (int j = 0; j < n; j++) {
274                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
275                 }
276                 setPoint(i, new PointValuePair(x, Double.NaN, false));
277             }
278             evaluate(evaluationFunction, comparator);
279         }
280     }
281 }