001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.analysis.solvers; 018 019import org.apache.commons.math3.analysis.UnivariateFunction; 020import org.apache.commons.math3.exception.NoBracketingException; 021import org.apache.commons.math3.exception.NotStrictlyPositiveException; 022import org.apache.commons.math3.exception.NullArgumentException; 023import org.apache.commons.math3.exception.NumberIsTooLargeException; 024import org.apache.commons.math3.exception.util.LocalizedFormats; 025import org.apache.commons.math3.util.FastMath; 026 027/** 028 * Utility routines for {@link UnivariateSolver} objects. 029 * 030 */ 031public class UnivariateSolverUtils { 032 /** 033 * Class contains only static methods. 034 */ 035 private UnivariateSolverUtils() {} 036 037 /** 038 * Convenience method to find a zero of a univariate real function. A default 039 * solver is used. 040 * 041 * @param function Function. 042 * @param x0 Lower bound for the interval. 043 * @param x1 Upper bound for the interval. 044 * @return a value where the function is zero. 045 * @throws NoBracketingException if the function has the same sign at the 046 * endpoints. 047 * @throws NullArgumentException if {@code function} is {@code null}. 048 */ 049 public static double solve(UnivariateFunction function, double x0, double x1) 050 throws NullArgumentException, 051 NoBracketingException { 052 if (function == null) { 053 throw new NullArgumentException(LocalizedFormats.FUNCTION); 054 } 055 final UnivariateSolver solver = new BrentSolver(); 056 return solver.solve(Integer.MAX_VALUE, function, x0, x1); 057 } 058 059 /** 060 * Convenience method to find a zero of a univariate real function. A default 061 * solver is used. 062 * 063 * @param function Function. 064 * @param x0 Lower bound for the interval. 065 * @param x1 Upper bound for the interval. 066 * @param absoluteAccuracy Accuracy to be used by the solver. 067 * @return a value where the function is zero. 068 * @throws NoBracketingException if the function has the same sign at the 069 * endpoints. 070 * @throws NullArgumentException if {@code function} is {@code null}. 071 */ 072 public static double solve(UnivariateFunction function, 073 double x0, double x1, 074 double absoluteAccuracy) 075 throws NullArgumentException, 076 NoBracketingException { 077 if (function == null) { 078 throw new NullArgumentException(LocalizedFormats.FUNCTION); 079 } 080 final UnivariateSolver solver = new BrentSolver(absoluteAccuracy); 081 return solver.solve(Integer.MAX_VALUE, function, x0, x1); 082 } 083 084 /** 085 * Force a root found by a non-bracketing solver to lie on a specified side, 086 * as if the solver were a bracketing one. 087 * 088 * @param maxEval maximal number of new evaluations of the function 089 * (evaluations already done for finding the root should have already been subtracted 090 * from this number) 091 * @param f function to solve 092 * @param bracketing bracketing solver to use for shifting the root 093 * @param baseRoot original root found by a previous non-bracketing solver 094 * @param min minimal bound of the search interval 095 * @param max maximal bound of the search interval 096 * @param allowedSolution the kind of solutions that the root-finding algorithm may 097 * accept as solutions. 098 * @return a root approximation, on the specified side of the exact root 099 * @throws NoBracketingException if the function has the same sign at the 100 * endpoints. 101 */ 102 public static double forceSide(final int maxEval, final UnivariateFunction f, 103 final BracketedUnivariateSolver<UnivariateFunction> bracketing, 104 final double baseRoot, final double min, final double max, 105 final AllowedSolution allowedSolution) 106 throws NoBracketingException { 107 108 if (allowedSolution == AllowedSolution.ANY_SIDE) { 109 // no further bracketing required 110 return baseRoot; 111 } 112 113 // find a very small interval bracketing the root 114 final double step = FastMath.max(bracketing.getAbsoluteAccuracy(), 115 FastMath.abs(baseRoot * bracketing.getRelativeAccuracy())); 116 double xLo = FastMath.max(min, baseRoot - step); 117 double fLo = f.value(xLo); 118 double xHi = FastMath.min(max, baseRoot + step); 119 double fHi = f.value(xHi); 120 int remainingEval = maxEval - 2; 121 while (remainingEval > 0) { 122 123 if ((fLo >= 0 && fHi <= 0) || (fLo <= 0 && fHi >= 0)) { 124 // compute the root on the selected side 125 return bracketing.solve(remainingEval, f, xLo, xHi, baseRoot, allowedSolution); 126 } 127 128 // try increasing the interval 129 boolean changeLo = false; 130 boolean changeHi = false; 131 if (fLo < fHi) { 132 // increasing function 133 if (fLo >= 0) { 134 changeLo = true; 135 } else { 136 changeHi = true; 137 } 138 } else if (fLo > fHi) { 139 // decreasing function 140 if (fLo <= 0) { 141 changeLo = true; 142 } else { 143 changeHi = true; 144 } 145 } else { 146 // unknown variation 147 changeLo = true; 148 changeHi = true; 149 } 150 151 // update the lower bound 152 if (changeLo) { 153 xLo = FastMath.max(min, xLo - step); 154 fLo = f.value(xLo); 155 remainingEval--; 156 } 157 158 // update the higher bound 159 if (changeHi) { 160 xHi = FastMath.min(max, xHi + step); 161 fHi = f.value(xHi); 162 remainingEval--; 163 } 164 165 } 166 167 throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING, 168 xLo, xHi, fLo, fHi, 169 maxEval - remainingEval, maxEval, baseRoot, 170 min, max); 171 172 } 173 174 /** 175 * This method simply calls {@link #bracket(UnivariateFunction, double, double, double, 176 * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)} 177 * with {@code q} and {@code r} set to 1.0 and {@code maximumIterations} set to {@code Integer.MAX_VALUE}. 178 * <p> 179 * <strong>Note: </strong> this method can take {@code Integer.MAX_VALUE} 180 * iterations to throw a {@code ConvergenceException.} Unless you are 181 * confident that there is a root between {@code lowerBound} and 182 * {@code upperBound} near {@code initial}, it is better to use 183 * {@link #bracket(UnivariateFunction, double, double, double, double,double, int) 184 * bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}, 185 * explicitly specifying the maximum number of iterations.</p> 186 * 187 * @param function Function. 188 * @param initial Initial midpoint of interval being expanded to 189 * bracket a root. 190 * @param lowerBound Lower bound (a is never lower than this value) 191 * @param upperBound Upper bound (b never is greater than this 192 * value). 193 * @return a two-element array holding a and b. 194 * @throws NoBracketingException if a root cannot be bracketted. 195 * @throws NotStrictlyPositiveException if {@code maximumIterations <= 0}. 196 * @throws NullArgumentException if {@code function} is {@code null}. 197 */ 198 public static double[] bracket(UnivariateFunction function, 199 double initial, 200 double lowerBound, double upperBound) 201 throws NullArgumentException, 202 NotStrictlyPositiveException, 203 NoBracketingException { 204 return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, Integer.MAX_VALUE); 205 } 206 207 /** 208 * This method simply calls {@link #bracket(UnivariateFunction, double, double, double, 209 * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)} 210 * with {@code q} and {@code r} set to 1.0. 211 * @param function Function. 212 * @param initial Initial midpoint of interval being expanded to 213 * bracket a root. 214 * @param lowerBound Lower bound (a is never lower than this value). 215 * @param upperBound Upper bound (b never is greater than this 216 * value). 217 * @param maximumIterations Maximum number of iterations to perform 218 * @return a two element array holding a and b. 219 * @throws NoBracketingException if the algorithm fails to find a and b 220 * satisfying the desired conditions. 221 * @throws NotStrictlyPositiveException if {@code maximumIterations <= 0}. 222 * @throws NullArgumentException if {@code function} is {@code null}. 223 */ 224 public static double[] bracket(UnivariateFunction function, 225 double initial, 226 double lowerBound, double upperBound, 227 int maximumIterations) 228 throws NullArgumentException, 229 NotStrictlyPositiveException, 230 NoBracketingException { 231 return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, maximumIterations); 232 } 233 234 /** 235 * This method attempts to find two values a and b satisfying <ul> 236 * <li> {@code lowerBound <= a < initial < b <= upperBound} </li> 237 * <li> {@code f(a) * f(b) <= 0} </li> 238 * </ul> 239 * If {@code f} is continuous on {@code [a,b]}, this means that {@code a} 240 * and {@code b} bracket a root of {@code f}. 241 * <p> 242 * The algorithm checks the sign of \( f(l_k) \) and \( f(u_k) \) for increasing 243 * values of k, where \( l_k = max(lower, initial - \delta_k) \), 244 * \( u_k = min(upper, initial + \delta_k) \), using recurrence 245 * \( \delta_{k+1} = r \delta_k + q, \delta_0 = 0\) and starting search with \( k=1 \). 246 * The algorithm stops when one of the following happens: <ul> 247 * <li> at least one positive and one negative value have been found -- success!</li> 248 * <li> both endpoints have reached their respective limits -- NoBracketingException </li> 249 * <li> {@code maximumIterations} iterations elapse -- NoBracketingException </li></ul> 250 * <p> 251 * If different signs are found at first iteration ({@code k=1}), then the returned 252 * interval will be \( [a, b] = [l_1, u_1] \). If different signs are found at a later 253 * iteration {@code k>1}, then the returned interval will be either 254 * \( [a, b] = [l_{k+1}, l_{k}] \) or \( [a, b] = [u_{k}, u_{k+1}] \). A root solver called 255 * with these parameters will therefore start with the smallest bracketing interval known 256 * at this step. 257 * </p> 258 * <p> 259 * Interval expansion rate is tuned by changing the recurrence parameters {@code r} and 260 * {@code q}. When the multiplicative factor {@code r} is set to 1, the sequence is a 261 * simple arithmetic sequence with linear increase. When the multiplicative factor {@code r} 262 * is larger than 1, the sequence has an asymptotically exponential rate. Note than the 263 * additive parameter {@code q} should never be set to zero, otherwise the interval would 264 * degenerate to the single initial point for all values of {@code k}. 265 * </p> 266 * <p> 267 * As a rule of thumb, when the location of the root is expected to be approximately known 268 * within some error margin, {@code r} should be set to 1 and {@code q} should be set to the 269 * order of magnitude of the error margin. When the location of the root is really a wild guess, 270 * then {@code r} should be set to a value larger than 1 (typically 2 to double the interval 271 * length at each iteration) and {@code q} should be set according to half the initial 272 * search interval length. 273 * </p> 274 * <p> 275 * As an example, if we consider the trivial function {@code f(x) = 1 - x} and use 276 * {@code initial = 4}, {@code r = 1}, {@code q = 2}, the algorithm will compute 277 * {@code f(4-2) = f(2) = -1} and {@code f(4+2) = f(6) = -5} for {@code k = 1}, then 278 * {@code f(4-4) = f(0) = +1} and {@code f(4+4) = f(8) = -7} for {@code k = 2}. Then it will 279 * return the interval {@code [0, 2]} as the smallest one known to be bracketing the root. 280 * As shown by this example, the initial value (here {@code 4}) may lie outside of the returned 281 * bracketing interval. 282 * </p> 283 * @param function function to check 284 * @param initial Initial midpoint of interval being expanded to 285 * bracket a root. 286 * @param lowerBound Lower bound (a is never lower than this value). 287 * @param upperBound Upper bound (b never is greater than this 288 * value). 289 * @param q additive offset used to compute bounds sequence (must be strictly positive) 290 * @param r multiplicative factor used to compute bounds sequence 291 * @param maximumIterations Maximum number of iterations to perform 292 * @return a two element array holding the bracketing values. 293 * @exception NoBracketingException if function cannot be bracketed in the search interval 294 */ 295 public static double[] bracket(final UnivariateFunction function, final double initial, 296 final double lowerBound, final double upperBound, 297 final double q, final double r, final int maximumIterations) 298 throws NoBracketingException { 299 300 if (function == null) { 301 throw new NullArgumentException(LocalizedFormats.FUNCTION); 302 } 303 if (q <= 0) { 304 throw new NotStrictlyPositiveException(q); 305 } 306 if (maximumIterations <= 0) { 307 throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations); 308 } 309 verifySequence(lowerBound, initial, upperBound); 310 311 // initialize the recurrence 312 double a = initial; 313 double b = initial; 314 double fa = Double.NaN; 315 double fb = Double.NaN; 316 double delta = 0; 317 318 for (int numIterations = 0; 319 (numIterations < maximumIterations) && (a > lowerBound || b < upperBound); 320 ++numIterations) { 321 322 final double previousA = a; 323 final double previousFa = fa; 324 final double previousB = b; 325 final double previousFb = fb; 326 327 delta = r * delta + q; 328 a = FastMath.max(initial - delta, lowerBound); 329 b = FastMath.min(initial + delta, upperBound); 330 fa = function.value(a); 331 fb = function.value(b); 332 333 if (numIterations == 0) { 334 // at first iteration, we don't have a previous interval 335 // we simply compare both sides of the initial interval 336 if (fa * fb <= 0) { 337 // the first interval already brackets a root 338 return new double[] { a, b }; 339 } 340 } else { 341 // we have a previous interval with constant sign and expand it, 342 // we expect sign changes to occur at boundaries 343 if (fa * previousFa <= 0) { 344 // sign change detected at near lower bound 345 return new double[] { a, previousA }; 346 } else if (fb * previousFb <= 0) { 347 // sign change detected at near upper bound 348 return new double[] { previousB, b }; 349 } 350 } 351 352 } 353 354 // no bracketing found 355 throw new NoBracketingException(a, b, fa, fb); 356 357 } 358 359 /** 360 * Compute the midpoint of two values. 361 * 362 * @param a first value. 363 * @param b second value. 364 * @return the midpoint. 365 */ 366 public static double midpoint(double a, double b) { 367 return (a + b) * 0.5; 368 } 369 370 /** 371 * Check whether the interval bounds bracket a root. That is, if the 372 * values at the endpoints are not equal to zero, then the function takes 373 * opposite signs at the endpoints. 374 * 375 * @param function Function. 376 * @param lower Lower endpoint. 377 * @param upper Upper endpoint. 378 * @return {@code true} if the function values have opposite signs at the 379 * given points. 380 * @throws NullArgumentException if {@code function} is {@code null}. 381 */ 382 public static boolean isBracketing(UnivariateFunction function, 383 final double lower, 384 final double upper) 385 throws NullArgumentException { 386 if (function == null) { 387 throw new NullArgumentException(LocalizedFormats.FUNCTION); 388 } 389 final double fLo = function.value(lower); 390 final double fHi = function.value(upper); 391 return (fLo >= 0 && fHi <= 0) || (fLo <= 0 && fHi >= 0); 392 } 393 394 /** 395 * Check whether the arguments form a (strictly) increasing sequence. 396 * 397 * @param start First number. 398 * @param mid Second number. 399 * @param end Third number. 400 * @return {@code true} if the arguments form an increasing sequence. 401 */ 402 public static boolean isSequence(final double start, 403 final double mid, 404 final double end) { 405 return (start < mid) && (mid < end); 406 } 407 408 /** 409 * Check that the endpoints specify an interval. 410 * 411 * @param lower Lower endpoint. 412 * @param upper Upper endpoint. 413 * @throws NumberIsTooLargeException if {@code lower >= upper}. 414 */ 415 public static void verifyInterval(final double lower, 416 final double upper) 417 throws NumberIsTooLargeException { 418 if (lower >= upper) { 419 throw new NumberIsTooLargeException(LocalizedFormats.ENDPOINTS_NOT_AN_INTERVAL, 420 lower, upper, false); 421 } 422 } 423 424 /** 425 * Check that {@code lower < initial < upper}. 426 * 427 * @param lower Lower endpoint. 428 * @param initial Initial value. 429 * @param upper Upper endpoint. 430 * @throws NumberIsTooLargeException if {@code lower >= initial} or 431 * {@code initial >= upper}. 432 */ 433 public static void verifySequence(final double lower, 434 final double initial, 435 final double upper) 436 throws NumberIsTooLargeException { 437 verifyInterval(lower, initial); 438 verifyInterval(initial, upper); 439 } 440 441 /** 442 * Check that the endpoints specify an interval and the end points 443 * bracket a root. 444 * 445 * @param function Function. 446 * @param lower Lower endpoint. 447 * @param upper Upper endpoint. 448 * @throws NoBracketingException if the function has the same sign at the 449 * endpoints. 450 * @throws NullArgumentException if {@code function} is {@code null}. 451 */ 452 public static void verifyBracketing(UnivariateFunction function, 453 final double lower, 454 final double upper) 455 throws NullArgumentException, 456 NoBracketingException { 457 if (function == null) { 458 throw new NullArgumentException(LocalizedFormats.FUNCTION); 459 } 460 verifyInterval(lower, upper); 461 if (!isBracketing(function, lower, upper)) { 462 throw new NoBracketingException(lower, upper, 463 function.value(lower), 464 function.value(upper)); 465 } 466 } 467}