001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv; 018 019import java.util.Arrays; 020import java.util.List; 021import java.util.ArrayList; 022import java.util.Comparator; 023import java.util.Collections; 024import java.util.Objects; 025import java.util.function.UnaryOperator; 026import java.util.function.IntSupplier; 027import java.util.concurrent.CopyOnWriteArrayList; 028 029import org.apache.commons.math4.legacy.core.MathArrays; 030import org.apache.commons.math4.legacy.analysis.MultivariateFunction; 031import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException; 032import org.apache.commons.math4.legacy.exception.MathInternalError; 033import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; 034import org.apache.commons.math4.legacy.optim.ConvergenceChecker; 035import org.apache.commons.math4.legacy.optim.OptimizationData; 036import org.apache.commons.math4.legacy.optim.PointValuePair; 037import org.apache.commons.math4.legacy.optim.SimpleValueChecker; 038import org.apache.commons.math4.legacy.optim.InitialGuess; 039import org.apache.commons.math4.legacy.optim.MaxEval; 040import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType; 041import org.apache.commons.math4.legacy.optim.nonlinear.scalar.MultivariateOptimizer; 042import org.apache.commons.math4.legacy.optim.nonlinear.scalar.SimulatedAnnealing; 043import org.apache.commons.math4.legacy.optim.nonlinear.scalar.PopulationSize; 044import org.apache.commons.math4.legacy.optim.nonlinear.scalar.ObjectiveFunction; 045 046/** 047 * This class implements simplex-based direct search optimization. 048 * 049 * <p> 050 * Direct search methods only use objective function values, they do 051 * not need derivatives and don't either try to compute approximation 052 * of the derivatives. According to a 1996 paper by Margaret H. Wright 053 * (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct 054 * Search Methods: Once Scorned, Now Respectable</a>), they are used 055 * when either the computation of the derivative is impossible (noisy 056 * functions, unpredictable discontinuities) or difficult (complexity, 057 * computation cost). In the first cases, rather than an optimum, a 058 * <em>not too bad</em> point is desired. In the latter cases, an 059 * optimum is desired but cannot be reasonably found. In all cases 060 * direct search methods can be useful. 061 * 062 * <p> 063 * Simplex-based direct search methods are based on comparison of 064 * the objective function values at the vertices of a simplex (which is a 065 * set of n+1 points in dimension n) that is updated by the algorithms 066 * steps. 067 * 068 * <p> 069 * In addition to those documented in 070 * {@link MultivariateOptimizer#optimize(OptimizationData[]) MultivariateOptimizer}, 071 * an instance of this class will register the following data: 072 * <ul> 073 * <li>{@link Simplex}</li> 074 * <li>{@link Simplex.TransformFactory}</li> 075 * <li>{@link SimulatedAnnealing}</li> 076 * <li>{@link PopulationSize}</li> 077 * </ul> 078 * 079 * <p> 080 * Each call to {@code optimize} will re-use the start configuration of 081 * the current simplex and move it such that its first vertex is at the 082 * provided start point of the optimization. 083 * If the {@code optimize} method is called to solve a different problem 084 * and the number of parameters change, the simplex must be re-initialized 085 * to one with the appropriate dimensions. 086 * 087 * <p> 088 * Convergence is considered achieved when <em>all</em> the simplex points 089 * have converged. 090 * <p> 091 * Whenever {@link SimulatedAnnealing simulated annealing (SA)} is activated, 092 * and the SA phase has completed, convergence has probably not been reached 093 * yet; whenever it's the case, an additional (non-SA) search will be performed 094 * (using the current best simplex point as a start point). 095 * <p> 096 * Additional "best list" searches can be requested through setting the 097 * {@link PopulationSize} argument of the {@link #optimize(OptimizationData[]) 098 * optimize} method. 099 * 100 * <p> 101 * This implementation does not directly support constrained optimization 102 * with simple bounds. 103 * The call to {@link #optimize(OptimizationData[]) optimize} will throw 104 * {@link MathUnsupportedOperationException} if bounds are passed to it. 105 * 106 * @see NelderMeadTransform 107 * @see MultiDirectionalTransform 108 * @see HedarFukushimaTransform 109 */ 110public class SimplexOptimizer extends MultivariateOptimizer { 111 /** Default simplex side length ratio. */ 112 private static final double SIMPLEX_SIDE_RATIO = 1e-1; 113 /** Simplex update function factory. */ 114 private Simplex.TransformFactory updateRule; 115 /** Initial simplex. */ 116 private Simplex initialSimplex; 117 /** Simulated annealing setup (optional). */ 118 private SimulatedAnnealing simulatedAnnealing = null; 119 /** User-defined number of additional optimizations (optional). */ 120 private int populationSize = 0; 121 /** Actual number of additional optimizations. */ 122 private int additionalSearch = 0; 123 /** Callbacks. */ 124 private final List<Observer> callbacks = new CopyOnWriteArrayList<>(); 125 126 /** 127 * @param checker Convergence checker. 128 */ 129 public SimplexOptimizer(ConvergenceChecker<PointValuePair> checker) { 130 super(checker); 131 } 132 133 /** 134 * @param rel Relative threshold. 135 * @param abs Absolute threshold. 136 */ 137 public SimplexOptimizer(double rel, 138 double abs) { 139 this(new SimpleValueChecker(rel, abs)); 140 } 141 142 /** 143 * Callback interface for updating caller's code with the current 144 * state of the optimization. 145 */ 146 @FunctionalInterface 147 public interface Observer { 148 /** 149 * Method called after each modification of the {@code simplex}. 150 * 151 * @param simplex Current simplex. 152 * @param isInit {@code true} at the start of a new search (either 153 * "main" or "best list"), after the initial simplex's vertices 154 * have been evaluated. 155 * @param numEval Number of evaluations of the objective function. 156 */ 157 void update(Simplex simplex, 158 boolean isInit, 159 int numEval); 160 } 161 162 /** 163 * Register a callback. 164 * 165 * @param cb Callback. 166 */ 167 public void addObserver(Observer cb) { 168 Objects.requireNonNull(cb, "Callback"); 169 callbacks.add(cb); 170 } 171 172 /** {@inheritDoc} */ 173 @Override 174 protected PointValuePair doOptimize() { 175 checkParameters(); 176 177 // Indirect call to "computeObjectiveValue" in order to update the 178 // evaluations counter. 179 final MultivariateFunction evalFunc = this::computeObjectiveValue; 180 181 final boolean isMinim = getGoalType() == GoalType.MINIMIZE; 182 final Comparator<PointValuePair> comparator = (o1, o2) -> { 183 final double v1 = o1.getValue(); 184 final double v2 = o2.getValue(); 185 return isMinim ? Double.compare(v1, v2) : Double.compare(v2, v1); 186 }; 187 188 // Start points for additional search. 189 final List<PointValuePair> bestList = new ArrayList<>(); 190 191 Simplex currentSimplex = initialSimplex.translate(getStartPoint()).evaluate(evalFunc, comparator); 192 notifyObservers(currentSimplex, true); 193 double temperature = Double.NaN; // Only used with simulated annealing. 194 Simplex previousSimplex = null; 195 196 if (simulatedAnnealing != null) { 197 temperature = 198 temperature(currentSimplex.get(0), 199 currentSimplex.get(currentSimplex.getDimension()), 200 simulatedAnnealing.getStartProbability()); 201 } 202 203 while (true) { 204 if (previousSimplex != null) { // Skip check at first iteration. 205 if (hasConverged(previousSimplex, currentSimplex)) { 206 return currentSimplex.get(0); 207 } 208 } 209 210 // We still need to search. 211 previousSimplex = currentSimplex; 212 213 if (simulatedAnnealing != null) { 214 // Update current temperature. 215 temperature = 216 simulatedAnnealing.getCoolingSchedule().apply(temperature, 217 currentSimplex); 218 219 final double endTemperature = 220 temperature(currentSimplex.get(0), 221 currentSimplex.get(currentSimplex.getDimension()), 222 simulatedAnnealing.getEndProbability()); 223 224 if (temperature < endTemperature) { 225 break; 226 } 227 228 final UnaryOperator<Simplex> update = 229 updateRule.create(evalFunc, 230 comparator, 231 simulatedAnnealing.metropolis(temperature)); 232 233 for (int i = 0; i < simulatedAnnealing.getEpochDuration(); i++) { 234 // Simplex is transformed (and observers are notified). 235 currentSimplex = applyUpdate(update, 236 currentSimplex, 237 evalFunc, 238 comparator); 239 } 240 } else { 241 // No simulated annealing. 242 final UnaryOperator<Simplex> update = 243 updateRule.create(evalFunc, comparator, null); 244 245 // Simplex is transformed (and observers are notified). 246 currentSimplex = applyUpdate(update, 247 currentSimplex, 248 evalFunc, 249 comparator); 250 } 251 252 if (additionalSearch != 0) { 253 // In "bestList", we must keep track of at least two points 254 // in order to be able to compute the new initial simplex for 255 // the additional search. 256 final int max = Math.max(additionalSearch, 2); 257 258 // Store best points. 259 for (int i = 0; i < currentSimplex.getSize(); i++) { 260 keepIfBetter(currentSimplex.get(i), 261 comparator, 262 bestList, 263 max); 264 } 265 } 266 267 incrementIterationCount(); 268 } 269 270 // No convergence. 271 272 if (additionalSearch > 0) { 273 // Additional optimizations. 274 // Reference to counter in the "main" search in order to retrieve 275 // the total number of evaluations in the "best list" search. 276 final IntSupplier evalCount = () -> getEvaluations(); 277 278 return bestListSearch(evalFunc, 279 comparator, 280 bestList, 281 evalCount); 282 } 283 284 throw new MathInternalError(); // Should never happen. 285 } 286 287 /** 288 * Scans the list of (required and optional) optimization data that 289 * characterize the problem. 290 * 291 * @param optData Optimization data. 292 * The following data will be looked for: 293 * <ul> 294 * <li>{@link Simplex}</li> 295 * <li>{@link Simplex.TransformFactory}</li> 296 * <li>{@link SimulatedAnnealing}</li> 297 * <li>{@link PopulationSize}</li> 298 * </ul> 299 */ 300 @Override 301 protected void parseOptimizationData(OptimizationData... optData) { 302 // Allow base class to register its own data. 303 super.parseOptimizationData(optData); 304 305 // The existing values (as set by the previous call) are reused 306 // if not provided in the argument list. 307 for (OptimizationData data : optData) { 308 if (data instanceof Simplex) { 309 initialSimplex = (Simplex) data; 310 } else if (data instanceof Simplex.TransformFactory) { 311 updateRule = (Simplex.TransformFactory) data; 312 } else if (data instanceof SimulatedAnnealing) { 313 simulatedAnnealing = (SimulatedAnnealing) data; 314 } else if (data instanceof PopulationSize) { 315 populationSize = ((PopulationSize) data).getPopulationSize(); 316 } 317 } 318 } 319 320 /** 321 * Detects whether the simplex has shrunk below the user-defined 322 * tolerance. 323 * 324 * @param previous Simplex at previous iteration. 325 * @param current Simplex at current iteration. 326 * @return {@code true} if convergence is considered achieved. 327 */ 328 private boolean hasConverged(Simplex previous, 329 Simplex current) { 330 final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker(); 331 332 for (int i = 0; i < current.getSize(); i++) { 333 final PointValuePair prev = previous.get(i); 334 final PointValuePair curr = current.get(i); 335 336 if (!checker.converged(getIterations(), prev, curr)) { 337 return false; 338 } 339 } 340 341 return true; 342 } 343 344 /** 345 * @throws MathUnsupportedOperationException if bounds were passed to the 346 * {@link #optimize(OptimizationData[]) optimize} method. 347 * @throws NullPointerException if no initial simplex or no transform rule 348 * was passed to the {@link #optimize(OptimizationData[]) optimize} method. 349 * @throws IllegalArgumentException if {@link #populationSize} is negative. 350 */ 351 private void checkParameters() { 352 Objects.requireNonNull(updateRule, "Update rule"); 353 Objects.requireNonNull(initialSimplex, "Initial simplex"); 354 355 if (getLowerBound() != null || 356 getUpperBound() != null) { 357 throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT); 358 } 359 360 if (populationSize < 0) { 361 throw new IllegalArgumentException("Population size"); 362 } 363 364 additionalSearch = simulatedAnnealing == null ? 365 Math.max(0, populationSize) : 366 Math.max(1, populationSize); 367 } 368 369 /** 370 * Computes the temperature as a function of the acceptance probability 371 * and the fitness difference between two of the simplex vertices (usually 372 * the best and worst points). 373 * 374 * @param p1 Simplex point. 375 * @param p2 Simplex point. 376 * @param prob Acceptance probability. 377 * @return the temperature. 378 */ 379 private double temperature(PointValuePair p1, 380 PointValuePair p2, 381 double prob) { 382 return -Math.abs(p1.getValue() - p2.getValue()) / Math.log(prob); 383 } 384 385 /** 386 * Stores the given {@code candidate} if its fitness is better than 387 * that of the last (assumed to be the worst) point in {@code list}. 388 * 389 * <p>If the list is below the maximum size then the {@code candidate} 390 * is added if it is not already in the list. The list is sorted 391 * when it reaches the maximum size. 392 * 393 * @param candidate Point to be stored. 394 * @param comp Fitness comparator. 395 * @param list Starting points (modified in-place). 396 * @param max Maximum size of the {@code list}. 397 */ 398 private static void keepIfBetter(PointValuePair candidate, 399 Comparator<PointValuePair> comp, 400 List<PointValuePair> list, 401 int max) { 402 final int listSize = list.size(); 403 final double[] candidatePoint = candidate.getPoint(); 404 if (listSize == 0) { 405 list.add(candidate); 406 } else if (listSize < max) { 407 // List is not fully populated yet. 408 for (PointValuePair p : list) { 409 final double[] pPoint = p.getPoint(); 410 if (Arrays.equals(pPoint, candidatePoint)) { 411 // Point was already stored. 412 return; 413 } 414 } 415 // Store candidate. 416 list.add(candidate); 417 // Sort the list when required 418 if (list.size() == max) { 419 Collections.sort(list, comp); 420 } 421 } else { 422 final int last = max - 1; 423 if (comp.compare(candidate, list.get(last)) < 0) { 424 for (PointValuePair p : list) { 425 final double[] pPoint = p.getPoint(); 426 if (Arrays.equals(pPoint, candidatePoint)) { 427 // Point was already stored. 428 return; 429 } 430 } 431 432 // Store better candidate and reorder the list. 433 list.set(last, candidate); 434 Collections.sort(list, comp); 435 } 436 } 437 } 438 439 /** 440 * Computes the smallest distance between the given {@code point} 441 * and any of the other points in the {@code list}. 442 * 443 * @param point Point. 444 * @param list List. 445 * @return the smallest distance. 446 */ 447 private static double shortestDistance(PointValuePair point, 448 List<PointValuePair> list) { 449 double minDist = Double.POSITIVE_INFINITY; 450 451 final double[] p = point.getPoint(); 452 for (PointValuePair other : list) { 453 final double[] pOther = other.getPoint(); 454 if (!Arrays.equals(p, pOther)) { 455 final double dist = MathArrays.distance(p, pOther); 456 if (dist < minDist) { 457 minDist = dist; 458 } 459 } 460 } 461 462 return minDist; 463 } 464 465 /** 466 * Perform additional optimizations. 467 * 468 * @param evalFunc Objective function. 469 * @param comp Fitness comparator. 470 * @param bestList Best points encountered during the "main" search. 471 * List is assumed to be ordered from best to worst. 472 * @param evalCount Evaluation counter. 473 * @return the optimum. 474 */ 475 private PointValuePair bestListSearch(MultivariateFunction evalFunc, 476 Comparator<PointValuePair> comp, 477 List<PointValuePair> bestList, 478 IntSupplier evalCount) { 479 PointValuePair best = bestList.get(0); // Overall best result. 480 481 // Additional local optimizations using each of the best 482 // points visited during the main search. 483 for (int i = 0; i < additionalSearch; i++) { 484 final PointValuePair start = bestList.get(i); 485 // Find shortest distance to the other points. 486 final double dist = shortestDistance(start, bestList); 487 final double[] init = start.getPoint(); 488 // Create smaller initial simplex. 489 final Simplex simplex = Simplex.equalSidesAlongAxes(init.length, 490 SIMPLEX_SIDE_RATIO * dist); 491 492 final PointValuePair r = directSearch(init, 493 simplex, 494 evalFunc, 495 getConvergenceChecker(), 496 getGoalType(), 497 callbacks, 498 evalCount); 499 if (comp.compare(r, best) < 0) { 500 best = r; // New overall best. 501 } 502 } 503 504 return best; 505 } 506 507 /** 508 * @param init Start point. 509 * @param simplex Initial simplex. 510 * @param eval Objective function. 511 * Note: It is assumed that evaluations of this function are 512 * incrementing the main counter. 513 * @param checker Convergence checker. 514 * @param goalType Whether to minimize or maximize the objective function. 515 * @param cbList Callbacks. 516 * @param evalCount Evaluation counter. 517 * @return the optimum. 518 */ 519 private static PointValuePair directSearch(double[] init, 520 Simplex simplex, 521 MultivariateFunction eval, 522 ConvergenceChecker<PointValuePair> checker, 523 GoalType goalType, 524 List<Observer> cbList, 525 final IntSupplier evalCount) { 526 final SimplexOptimizer optim = new SimplexOptimizer(checker); 527 528 for (Observer cOrig : cbList) { 529 final SimplexOptimizer.Observer cNew = (spx, isInit, numEval) -> 530 cOrig.update(spx, isInit, evalCount.getAsInt()); 531 532 optim.addObserver(cNew); 533 } 534 535 return optim.optimize(MaxEval.unlimited(), 536 new ObjectiveFunction(eval), 537 goalType, 538 new InitialGuess(init), 539 simplex, 540 new MultiDirectionalTransform()); 541 } 542 543 /** 544 * @param simplex Current simplex. 545 * @param isInit Set to {@code true} at the start of a new search 546 * (either "main" or "best list"), after the evaluation of the initial 547 * simplex's vertices. 548 */ 549 private void notifyObservers(Simplex simplex, 550 boolean isInit) { 551 for (Observer cb : callbacks) { 552 cb.update(simplex, 553 isInit, 554 getEvaluations()); 555 } 556 } 557 558 /** 559 * Applies the {@code update} to the given {@code simplex} (and notifies 560 * observers). 561 * 562 * @param update Simplex transformation. 563 * @param simplex Current simplex. 564 * @param eval Objective function. 565 * @param comp Fitness comparator. 566 * @return the transformed simplex. 567 */ 568 private Simplex applyUpdate(UnaryOperator<Simplex> update, 569 Simplex simplex, 570 MultivariateFunction eval, 571 Comparator<PointValuePair> comp) { 572 final Simplex transformed = update.apply(simplex).evaluate(eval, comp); 573 574 notifyObservers(transformed, false); 575 576 return transformed; 577 } 578}