001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.optim.nonlinear.scalar.noderiv; 018 019import java.util.Comparator; 020 021import org.apache.commons.math3.optim.PointValuePair; 022import org.apache.commons.math3.analysis.MultivariateFunction; 023 024/** 025 * This class implements the Nelder-Mead simplex algorithm. 026 * 027 * @since 3.0 028 */ 029public class NelderMeadSimplex extends AbstractSimplex { 030 /** Default value for {@link #rho}: {@value}. */ 031 private static final double DEFAULT_RHO = 1; 032 /** Default value for {@link #khi}: {@value}. */ 033 private static final double DEFAULT_KHI = 2; 034 /** Default value for {@link #gamma}: {@value}. */ 035 private static final double DEFAULT_GAMMA = 0.5; 036 /** Default value for {@link #sigma}: {@value}. */ 037 private static final double DEFAULT_SIGMA = 0.5; 038 /** Reflection coefficient. */ 039 private final double rho; 040 /** Expansion coefficient. */ 041 private final double khi; 042 /** Contraction coefficient. */ 043 private final double gamma; 044 /** Shrinkage coefficient. */ 045 private final double sigma; 046 047 /** 048 * Build a Nelder-Mead simplex with default coefficients. 049 * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 050 * for both gamma and sigma. 051 * 052 * @param n Dimension of the simplex. 053 */ 054 public NelderMeadSimplex(final int n) { 055 this(n, 1d); 056 } 057 058 /** 059 * Build a Nelder-Mead simplex with default coefficients. 060 * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 061 * for both gamma and sigma. 062 * 063 * @param n Dimension of the simplex. 064 * @param sideLength Length of the sides of the default (hypercube) 065 * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}. 066 */ 067 public NelderMeadSimplex(final int n, double sideLength) { 068 this(n, sideLength, 069 DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA); 070 } 071 072 /** 073 * Build a Nelder-Mead simplex with specified coefficients. 074 * 075 * @param n Dimension of the simplex. See 076 * {@link AbstractSimplex#AbstractSimplex(int,double)}. 077 * @param sideLength Length of the sides of the default (hypercube) 078 * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}. 079 * @param rho Reflection coefficient. 080 * @param khi Expansion coefficient. 081 * @param gamma Contraction coefficient. 082 * @param sigma Shrinkage coefficient. 083 */ 084 public NelderMeadSimplex(final int n, double sideLength, 085 final double rho, final double khi, 086 final double gamma, final double sigma) { 087 super(n, sideLength); 088 089 this.rho = rho; 090 this.khi = khi; 091 this.gamma = gamma; 092 this.sigma = sigma; 093 } 094 095 /** 096 * Build a Nelder-Mead simplex with specified coefficients. 097 * 098 * @param n Dimension of the simplex. See 099 * {@link AbstractSimplex#AbstractSimplex(int)}. 100 * @param rho Reflection coefficient. 101 * @param khi Expansion coefficient. 102 * @param gamma Contraction coefficient. 103 * @param sigma Shrinkage coefficient. 104 */ 105 public NelderMeadSimplex(final int n, 106 final double rho, final double khi, 107 final double gamma, final double sigma) { 108 this(n, 1d, rho, khi, gamma, sigma); 109 } 110 111 /** 112 * Build a Nelder-Mead simplex with default coefficients. 113 * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 114 * for both gamma and sigma. 115 * 116 * @param steps Steps along the canonical axes representing box edges. 117 * They may be negative but not zero. See 118 */ 119 public NelderMeadSimplex(final double[] steps) { 120 this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA); 121 } 122 123 /** 124 * Build a Nelder-Mead simplex with specified coefficients. 125 * 126 * @param steps Steps along the canonical axes representing box edges. 127 * They may be negative but not zero. See 128 * {@link AbstractSimplex#AbstractSimplex(double[])}. 129 * @param rho Reflection coefficient. 130 * @param khi Expansion coefficient. 131 * @param gamma Contraction coefficient. 132 * @param sigma Shrinkage coefficient. 133 * @throws IllegalArgumentException if one of the steps is zero. 134 */ 135 public NelderMeadSimplex(final double[] steps, 136 final double rho, final double khi, 137 final double gamma, final double sigma) { 138 super(steps); 139 140 this.rho = rho; 141 this.khi = khi; 142 this.gamma = gamma; 143 this.sigma = sigma; 144 } 145 146 /** 147 * Build a Nelder-Mead simplex with default coefficients. 148 * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5 149 * for both gamma and sigma. 150 * 151 * @param referenceSimplex Reference simplex. See 152 * {@link AbstractSimplex#AbstractSimplex(double[][])}. 153 */ 154 public NelderMeadSimplex(final double[][] referenceSimplex) { 155 this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA); 156 } 157 158 /** 159 * Build a Nelder-Mead simplex with specified coefficients. 160 * 161 * @param referenceSimplex Reference simplex. See 162 * {@link AbstractSimplex#AbstractSimplex(double[][])}. 163 * @param rho Reflection coefficient. 164 * @param khi Expansion coefficient. 165 * @param gamma Contraction coefficient. 166 * @param sigma Shrinkage coefficient. 167 * @throws org.apache.commons.math3.exception.NotStrictlyPositiveException 168 * if the reference simplex does not contain at least one point. 169 * @throws org.apache.commons.math3.exception.DimensionMismatchException 170 * if there is a dimension mismatch in the reference simplex. 171 */ 172 public NelderMeadSimplex(final double[][] referenceSimplex, 173 final double rho, final double khi, 174 final double gamma, final double sigma) { 175 super(referenceSimplex); 176 177 this.rho = rho; 178 this.khi = khi; 179 this.gamma = gamma; 180 this.sigma = sigma; 181 } 182 183 /** {@inheritDoc} */ 184 @Override 185 public void iterate(final MultivariateFunction evaluationFunction, 186 final Comparator<PointValuePair> comparator) { 187 // The simplex has n + 1 points if dimension is n. 188 final int n = getDimension(); 189 190 // Interesting values. 191 final PointValuePair best = getPoint(0); 192 final PointValuePair secondBest = getPoint(n - 1); 193 final PointValuePair worst = getPoint(n); 194 final double[] xWorst = worst.getPointRef(); 195 196 // Compute the centroid of the best vertices (dismissing the worst 197 // point at index n). 198 final double[] centroid = new double[n]; 199 for (int i = 0; i < n; i++) { 200 final double[] x = getPoint(i).getPointRef(); 201 for (int j = 0; j < n; j++) { 202 centroid[j] += x[j]; 203 } 204 } 205 final double scaling = 1.0 / n; 206 for (int j = 0; j < n; j++) { 207 centroid[j] *= scaling; 208 } 209 210 // compute the reflection point 211 final double[] xR = new double[n]; 212 for (int j = 0; j < n; j++) { 213 xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]); 214 } 215 final PointValuePair reflected 216 = new PointValuePair(xR, evaluationFunction.value(xR), false); 217 218 if (comparator.compare(best, reflected) <= 0 && 219 comparator.compare(reflected, secondBest) < 0) { 220 // Accept the reflected point. 221 replaceWorstPoint(reflected, comparator); 222 } else if (comparator.compare(reflected, best) < 0) { 223 // Compute the expansion point. 224 final double[] xE = new double[n]; 225 for (int j = 0; j < n; j++) { 226 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]); 227 } 228 final PointValuePair expanded 229 = new PointValuePair(xE, evaluationFunction.value(xE), false); 230 231 if (comparator.compare(expanded, reflected) < 0) { 232 // Accept the expansion point. 233 replaceWorstPoint(expanded, comparator); 234 } else { 235 // Accept the reflected point. 236 replaceWorstPoint(reflected, comparator); 237 } 238 } else { 239 if (comparator.compare(reflected, worst) < 0) { 240 // Perform an outside contraction. 241 final double[] xC = new double[n]; 242 for (int j = 0; j < n; j++) { 243 xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]); 244 } 245 final PointValuePair outContracted 246 = new PointValuePair(xC, evaluationFunction.value(xC), false); 247 if (comparator.compare(outContracted, reflected) <= 0) { 248 // Accept the contraction point. 249 replaceWorstPoint(outContracted, comparator); 250 return; 251 } 252 } else { 253 // Perform an inside contraction. 254 final double[] xC = new double[n]; 255 for (int j = 0; j < n; j++) { 256 xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]); 257 } 258 final PointValuePair inContracted 259 = new PointValuePair(xC, evaluationFunction.value(xC), false); 260 261 if (comparator.compare(inContracted, worst) < 0) { 262 // Accept the contraction point. 263 replaceWorstPoint(inContracted, comparator); 264 return; 265 } 266 } 267 268 // Perform a shrink. 269 final double[] xSmallest = getPoint(0).getPointRef(); 270 for (int i = 1; i <= n; i++) { 271 final double[] x = getPoint(i).getPoint(); 272 for (int j = 0; j < n; j++) { 273 x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]); 274 } 275 setPoint(i, new PointValuePair(x, Double.NaN, false)); 276 } 277 evaluate(evaluationFunction, comparator); 278 } 279 } 280}