001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.legacy.optim.univariate; 018 019import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; 020import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; 021import org.apache.commons.math4.legacy.optim.ConvergenceChecker; 022import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType; 023import org.apache.commons.math4.core.jdkmath.JdkMath; 024import org.apache.commons.numbers.core.Precision; 025 026/** 027 * For a function defined on some interval {@code (lo, hi)}, this class 028 * finds an approximation {@code x} to the point at which the function 029 * attains its minimum. 030 * It implements Richard Brent's algorithm (from his book "Algorithms for 031 * Minimization without Derivatives", p. 79) for finding minima of real 032 * univariate functions. 033 * <br> 034 * This code is an adaptation, partly based on the Python code from SciPy 035 * (module "optimize.py" v0.5); the original algorithm is also modified 036 * <ul> 037 * <li>to use an initial guess provided by the user,</li> 038 * <li>to ensure that the best point encountered is the one returned.</li> 039 * </ul> 040 * 041 * @since 2.0 042 */ 043public class BrentOptimizer extends UnivariateOptimizer { 044 /** 045 * Golden section. 046 */ 047 private static final double GOLDEN_SECTION = 0.5 * (3 - JdkMath.sqrt(5)); 048 /** 049 * Minimum relative tolerance. 050 */ 051 private static final double MIN_RELATIVE_TOLERANCE = 2 * JdkMath.ulp(1d); 052 /** 053 * Relative threshold. 054 */ 055 private final double relativeThreshold; 056 /** 057 * Absolute threshold. 058 */ 059 private final double absoluteThreshold; 060 061 /** 062 * The arguments are used implement the original stopping criterion 063 * of Brent's algorithm. 064 * {@code abs} and {@code rel} define a tolerance 065 * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than 066 * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>, 067 * where <em>macheps</em> is the relative machine precision. {@code abs} must 068 * be positive. 069 * 070 * @param rel Relative threshold. 071 * @param abs Absolute threshold. 072 * @param checker Additional, user-defined, convergence checking 073 * procedure. 074 * @throws NotStrictlyPositiveException if {@code abs <= 0}. 075 * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}. 076 */ 077 public BrentOptimizer(double rel, 078 double abs, 079 ConvergenceChecker<UnivariatePointValuePair> checker) { 080 super(checker); 081 082 if (rel < MIN_RELATIVE_TOLERANCE) { 083 throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true); 084 } 085 if (abs <= 0) { 086 throw new NotStrictlyPositiveException(abs); 087 } 088 089 relativeThreshold = rel; 090 absoluteThreshold = abs; 091 } 092 093 /** 094 * The arguments are used for implementing the original stopping criterion 095 * of Brent's algorithm. 096 * {@code abs} and {@code rel} define a tolerance 097 * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than 098 * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>, 099 * where <em>macheps</em> is the relative machine precision. {@code abs} must 100 * be positive. 101 * 102 * @param rel Relative threshold. 103 * @param abs Absolute threshold. 104 * @throws NotStrictlyPositiveException if {@code abs <= 0}. 105 * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}. 106 */ 107 public BrentOptimizer(double rel, 108 double abs) { 109 this(rel, abs, null); 110 } 111 112 /** {@inheritDoc} */ 113 @Override 114 protected UnivariatePointValuePair doOptimize() { 115 final boolean isMinim = getGoalType() == GoalType.MINIMIZE; 116 final double lo = getMin(); 117 final double mid = getStartValue(); 118 final double hi = getMax(); 119 120 // Optional additional convergence criteria. 121 final ConvergenceChecker<UnivariatePointValuePair> checker 122 = getConvergenceChecker(); 123 124 double a; 125 double b; 126 if (lo < hi) { 127 a = lo; 128 b = hi; 129 } else { 130 a = hi; 131 b = lo; 132 } 133 134 double x = mid; 135 double v = x; 136 double w = x; 137 double d = 0; 138 double e = 0; 139 double fx = computeObjectiveValue(x); 140 if (!isMinim) { 141 fx = -fx; 142 } 143 double fv = fx; 144 double fw = fx; 145 146 UnivariatePointValuePair previous = null; 147 UnivariatePointValuePair current 148 = new UnivariatePointValuePair(x, isMinim ? fx : -fx); 149 // Best point encountered so far (which is the initial guess). 150 UnivariatePointValuePair best = current; 151 152 while (true) { 153 final double m = 0.5 * (a + b); 154 final double tol1 = relativeThreshold * JdkMath.abs(x) + absoluteThreshold; 155 final double tol2 = 2 * tol1; 156 157 // Default stopping criterion. 158 final boolean stop = JdkMath.abs(x - m) <= tol2 - 0.5 * (b - a); 159 if (!stop) { 160 double p = 0; 161 double q = 0; 162 double r = 0; 163 double u = 0; 164 165 if (JdkMath.abs(e) > tol1) { // Fit parabola. 166 r = (x - w) * (fx - fv); 167 q = (x - v) * (fx - fw); 168 p = (x - v) * q - (x - w) * r; 169 q = 2 * (q - r); 170 171 if (q > 0) { 172 p = -p; 173 } else { 174 q = -q; 175 } 176 177 r = e; 178 e = d; 179 180 if (p > q * (a - x) && 181 p < q * (b - x) && 182 JdkMath.abs(p) < JdkMath.abs(0.5 * q * r)) { 183 // Parabolic interpolation step. 184 d = p / q; 185 u = x + d; 186 187 // f must not be evaluated too close to a or b. 188 if (u - a < tol2 || b - u < tol2) { 189 if (x <= m) { 190 d = tol1; 191 } else { 192 d = -tol1; 193 } 194 } 195 } else { 196 // Golden section step. 197 if (x < m) { 198 e = b - x; 199 } else { 200 e = a - x; 201 } 202 d = GOLDEN_SECTION * e; 203 } 204 } else { 205 // Golden section step. 206 if (x < m) { 207 e = b - x; 208 } else { 209 e = a - x; 210 } 211 d = GOLDEN_SECTION * e; 212 } 213 214 // Update by at least "tol1". 215 if (JdkMath.abs(d) < tol1) { 216 if (d >= 0) { 217 u = x + tol1; 218 } else { 219 u = x - tol1; 220 } 221 } else { 222 u = x + d; 223 } 224 225 double fu = computeObjectiveValue(u); 226 if (!isMinim) { 227 fu = -fu; 228 } 229 230 // User-defined convergence checker. 231 previous = current; 232 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu); 233 best = best(best, 234 best(previous, 235 current, 236 isMinim), 237 isMinim); 238 239 if (checker != null && checker.converged(getIterations(), previous, current)) { 240 return best; 241 } 242 243 // Update a, b, v, w and x. 244 if (fu <= fx) { 245 if (u < x) { 246 b = x; 247 } else { 248 a = x; 249 } 250 v = w; 251 fv = fw; 252 w = x; 253 fw = fx; 254 x = u; 255 fx = fu; 256 } else { 257 if (u < x) { 258 a = u; 259 } else { 260 b = u; 261 } 262 if (fu <= fw || 263 Precision.equals(w, x)) { 264 v = w; 265 fv = fw; 266 w = u; 267 fw = fu; 268 } else if (fu <= fv || 269 Precision.equals(v, x) || 270 Precision.equals(v, w)) { 271 v = u; 272 fv = fu; 273 } 274 } 275 } else { // Default termination (Brent's criterion). 276 return best(best, 277 best(previous, 278 current, 279 isMinim), 280 isMinim); 281 } 282 283 incrementIterationCount(); 284 } 285 } 286 287 /** 288 * Selects the best of two points. 289 * 290 * @param a Point and value. 291 * @param b Point and value. 292 * @param isMinim {@code true} if the selected point must be the one with 293 * the lowest value. 294 * @return the best point, or {@code null} if {@code a} and {@code b} are 295 * both {@code null}. When {@code a} and {@code b} have the same function 296 * value, {@code a} is returned. 297 */ 298 private UnivariatePointValuePair best(UnivariatePointValuePair a, 299 UnivariatePointValuePair b, 300 boolean isMinim) { 301 if (a == null) { 302 return b; 303 } 304 if (b == null) { 305 return a; 306 } 307 308 if (isMinim) { 309 return a.getValue() <= b.getValue() ? a : b; 310 } else { 311 return a.getValue() >= b.getValue() ? a : b; 312 } 313 } 314}