001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv; 018 019import java.util.Comparator; 020import java.util.function.UnaryOperator; 021import java.util.function.DoublePredicate; 022 023import org.apache.commons.math4.legacy.analysis.MultivariateFunction; 024import org.apache.commons.math4.legacy.optim.PointValuePair; 025 026/** 027 * <a href="https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method">Nelder-Mead method</a>. 028 */ 029public class NelderMeadTransform 030 implements Simplex.TransformFactory { 031 /** Default value for {@link #alpha}: {@value}. */ 032 private static final double DEFAULT_ALPHA = 1; 033 /** Default value for {@link #gamma}: {@value}. */ 034 private static final double DEFAULT_GAMMA = 2; 035 /** Default value for {@link #rho}: {@value}. */ 036 private static final double DEFAULT_RHO = 0.5; 037 /** Default value for {@link #sigma}: {@value}. */ 038 private static final double DEFAULT_SIGMA = 0.5; 039 /** Reflection coefficient. */ 040 private final double alpha; 041 /** Expansion coefficient. */ 042 private final double gamma; 043 /** Contraction coefficient. */ 044 private final double rho; 045 /** Shrinkage coefficient. */ 046 private final double sigma; 047 048 /** 049 * @param alpha Reflection coefficient. 050 * @param gamma Expansion coefficient. 051 * @param rho Contraction coefficient. 052 * @param sigma Shrinkage coefficient. 053 */ 054 public NelderMeadTransform(double alpha, 055 double gamma, 056 double rho, 057 double sigma) { 058 this.alpha = alpha; 059 this.gamma = gamma; 060 this.rho = rho; 061 this.sigma = sigma; 062 } 063 064 /** 065 * Transform with default values. 066 */ 067 public NelderMeadTransform() { 068 this(DEFAULT_ALPHA, 069 DEFAULT_GAMMA, 070 DEFAULT_RHO, 071 DEFAULT_SIGMA); 072 } 073 074 /** {@inheritDoc} */ 075 @Override 076 public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction, 077 final Comparator<PointValuePair> comparator, 078 final DoublePredicate sa) { 079 return original -> { 080 // The simplex has n + 1 points if dimension is n. 081 final int n = original.getDimension(); 082 083 // Interesting values. 084 final PointValuePair best = original.get(0); 085 final PointValuePair secondWorst = original.get(n - 1); 086 final PointValuePair worst = original.get(n); 087 final double[] xWorst = worst.getPoint(); 088 089 // Centroid of the best vertices, dismissing the worst point (at index n). 090 final double[] centroid = Simplex.centroid(original.asList().subList(0, n)); 091 092 // Reflection. 093 final PointValuePair reflected = Simplex.newPoint(centroid, 094 -alpha, 095 xWorst, 096 evaluationFunction); 097 if (comparator.compare(reflected, secondWorst) < 0 && 098 comparator.compare(best, reflected) <= 0) { 099 return original.replaceLast(reflected); 100 } 101 102 if (comparator.compare(reflected, best) < 0) { 103 // Expansion. 104 final PointValuePair expanded = Simplex.newPoint(centroid, 105 -gamma, 106 xWorst, 107 evaluationFunction); 108 if (comparator.compare(expanded, reflected) < 0 || 109 (sa != null && 110 sa.test(expanded.getValue() - reflected.getValue()))) { 111 return original.replaceLast(expanded); 112 } else { 113 return original.replaceLast(reflected); 114 } 115 } 116 117 if (comparator.compare(reflected, worst) < 0) { 118 // Outside contraction. 119 final PointValuePair contracted = Simplex.newPoint(centroid, 120 rho, 121 reflected.getPoint(), 122 evaluationFunction); 123 if (comparator.compare(contracted, reflected) < 0) { 124 return original.replaceLast(contracted); // Accept contracted point. 125 } 126 } else { 127 // Inside contraction. 128 final PointValuePair contracted = Simplex.newPoint(centroid, 129 rho, 130 xWorst, 131 evaluationFunction); 132 if (comparator.compare(contracted, worst) < 0) { 133 return original.replaceLast(contracted); // Accept contracted point. 134 } 135 } 136 137 // Shrink. 138 return original.shrink(sigma, evaluationFunction); 139 }; 140 } 141 142 /** {@inheritDoc} */ 143 @Override 144 public String toString() { 145 return "Nelder-Mead [a=" + alpha + 146 " g=" + gamma + 147 " r=" + rho + 148 " s=" + sigma + "]"; 149 } 150}