001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.analysis.solvers;
018
019import org.apache.commons.math3.analysis.UnivariateFunction;
020import org.apache.commons.math3.exception.NoBracketingException;
021import org.apache.commons.math3.exception.NotStrictlyPositiveException;
022import org.apache.commons.math3.exception.NullArgumentException;
023import org.apache.commons.math3.exception.NumberIsTooLargeException;
024import org.apache.commons.math3.exception.util.LocalizedFormats;
025import org.apache.commons.math3.util.FastMath;
026
027/**
028 * Utility routines for {@link UnivariateSolver} objects.
029 *
030 */
031public class UnivariateSolverUtils {
032    /**
033     * Class contains only static methods.
034     */
035    private UnivariateSolverUtils() {}
036
037    /**
038     * Convenience method to find a zero of a univariate real function.  A default
039     * solver is used.
040     *
041     * @param function Function.
042     * @param x0 Lower bound for the interval.
043     * @param x1 Upper bound for the interval.
044     * @return a value where the function is zero.
045     * @throws NoBracketingException if the function has the same sign at the
046     * endpoints.
047     * @throws NullArgumentException if {@code function} is {@code null}.
048     */
049    public static double solve(UnivariateFunction function, double x0, double x1)
050        throws NullArgumentException,
051               NoBracketingException {
052        if (function == null) {
053            throw new NullArgumentException(LocalizedFormats.FUNCTION);
054        }
055        final UnivariateSolver solver = new BrentSolver();
056        return solver.solve(Integer.MAX_VALUE, function, x0, x1);
057    }
058
059    /**
060     * Convenience method to find a zero of a univariate real function.  A default
061     * solver is used.
062     *
063     * @param function Function.
064     * @param x0 Lower bound for the interval.
065     * @param x1 Upper bound for the interval.
066     * @param absoluteAccuracy Accuracy to be used by the solver.
067     * @return a value where the function is zero.
068     * @throws NoBracketingException if the function has the same sign at the
069     * endpoints.
070     * @throws NullArgumentException if {@code function} is {@code null}.
071     */
072    public static double solve(UnivariateFunction function,
073                               double x0, double x1,
074                               double absoluteAccuracy)
075        throws NullArgumentException,
076               NoBracketingException {
077        if (function == null) {
078            throw new NullArgumentException(LocalizedFormats.FUNCTION);
079        }
080        final UnivariateSolver solver = new BrentSolver(absoluteAccuracy);
081        return solver.solve(Integer.MAX_VALUE, function, x0, x1);
082    }
083
084    /**
085     * Force a root found by a non-bracketing solver to lie on a specified side,
086     * as if the solver were a bracketing one.
087     *
088     * @param maxEval maximal number of new evaluations of the function
089     * (evaluations already done for finding the root should have already been subtracted
090     * from this number)
091     * @param f function to solve
092     * @param bracketing bracketing solver to use for shifting the root
093     * @param baseRoot original root found by a previous non-bracketing solver
094     * @param min minimal bound of the search interval
095     * @param max maximal bound of the search interval
096     * @param allowedSolution the kind of solutions that the root-finding algorithm may
097     * accept as solutions.
098     * @return a root approximation, on the specified side of the exact root
099     * @throws NoBracketingException if the function has the same sign at the
100     * endpoints.
101     */
102    public static double forceSide(final int maxEval, final UnivariateFunction f,
103                                   final BracketedUnivariateSolver<UnivariateFunction> bracketing,
104                                   final double baseRoot, final double min, final double max,
105                                   final AllowedSolution allowedSolution)
106        throws NoBracketingException {
107
108        if (allowedSolution == AllowedSolution.ANY_SIDE) {
109            // no further bracketing required
110            return baseRoot;
111        }
112
113        // find a very small interval bracketing the root
114        final double step = FastMath.max(bracketing.getAbsoluteAccuracy(),
115                                         FastMath.abs(baseRoot * bracketing.getRelativeAccuracy()));
116        double xLo        = FastMath.max(min, baseRoot - step);
117        double fLo        = f.value(xLo);
118        double xHi        = FastMath.min(max, baseRoot + step);
119        double fHi        = f.value(xHi);
120        int remainingEval = maxEval - 2;
121        while (remainingEval > 0) {
122
123            if ((fLo >= 0 && fHi <= 0) || (fLo <= 0 && fHi >= 0)) {
124                // compute the root on the selected side
125                return bracketing.solve(remainingEval, f, xLo, xHi, baseRoot, allowedSolution);
126            }
127
128            // try increasing the interval
129            boolean changeLo = false;
130            boolean changeHi = false;
131            if (fLo < fHi) {
132                // increasing function
133                if (fLo >= 0) {
134                    changeLo = true;
135                } else {
136                    changeHi = true;
137                }
138            } else if (fLo > fHi) {
139                // decreasing function
140                if (fLo <= 0) {
141                    changeLo = true;
142                } else {
143                    changeHi = true;
144                }
145            } else {
146                // unknown variation
147                changeLo = true;
148                changeHi = true;
149            }
150
151            // update the lower bound
152            if (changeLo) {
153                xLo = FastMath.max(min, xLo - step);
154                fLo  = f.value(xLo);
155                remainingEval--;
156            }
157
158            // update the higher bound
159            if (changeHi) {
160                xHi = FastMath.min(max, xHi + step);
161                fHi  = f.value(xHi);
162                remainingEval--;
163            }
164
165        }
166
167        throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING,
168                                        xLo, xHi, fLo, fHi,
169                                        maxEval - remainingEval, maxEval, baseRoot,
170                                        min, max);
171
172    }
173
174    /**
175     * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
176     * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
177     * with {@code q} and {@code r} set to 1.0 and {@code maximumIterations} set to {@code Integer.MAX_VALUE}.
178     * <p>
179     * <strong>Note: </strong> this method can take {@code Integer.MAX_VALUE}
180     * iterations to throw a {@code ConvergenceException.}  Unless you are
181     * confident that there is a root between {@code lowerBound} and
182     * {@code upperBound} near {@code initial}, it is better to use
183     * {@link #bracket(UnivariateFunction, double, double, double, double,double, int)
184     * bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)},
185     * explicitly specifying the maximum number of iterations.</p>
186     *
187     * @param function Function.
188     * @param initial Initial midpoint of interval being expanded to
189     * bracket a root.
190     * @param lowerBound Lower bound (a is never lower than this value)
191     * @param upperBound Upper bound (b never is greater than this
192     * value).
193     * @return a two-element array holding a and b.
194     * @throws NoBracketingException if a root cannot be bracketted.
195     * @throws NotStrictlyPositiveException if {@code maximumIterations <= 0}.
196     * @throws NullArgumentException if {@code function} is {@code null}.
197     */
198    public static double[] bracket(UnivariateFunction function,
199                                   double initial,
200                                   double lowerBound, double upperBound)
201        throws NullArgumentException,
202               NotStrictlyPositiveException,
203               NoBracketingException {
204        return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, Integer.MAX_VALUE);
205    }
206
207     /**
208     * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
209     * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
210     * with {@code q} and {@code r} set to 1.0.
211     * @param function Function.
212     * @param initial Initial midpoint of interval being expanded to
213     * bracket a root.
214     * @param lowerBound Lower bound (a is never lower than this value).
215     * @param upperBound Upper bound (b never is greater than this
216     * value).
217     * @param maximumIterations Maximum number of iterations to perform
218     * @return a two element array holding a and b.
219     * @throws NoBracketingException if the algorithm fails to find a and b
220     * satisfying the desired conditions.
221     * @throws NotStrictlyPositiveException if {@code maximumIterations <= 0}.
222     * @throws NullArgumentException if {@code function} is {@code null}.
223     */
224    public static double[] bracket(UnivariateFunction function,
225                                   double initial,
226                                   double lowerBound, double upperBound,
227                                   int maximumIterations)
228        throws NullArgumentException,
229               NotStrictlyPositiveException,
230               NoBracketingException {
231        return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, maximumIterations);
232    }
233
234    /**
235     * This method attempts to find two values a and b satisfying <ul>
236     * <li> {@code lowerBound <= a < initial < b <= upperBound} </li>
237     * <li> {@code f(a) * f(b) <= 0} </li>
238     * </ul>
239     * If {@code f} is continuous on {@code [a,b]}, this means that {@code a}
240     * and {@code b} bracket a root of {@code f}.
241     * <p>
242     * The algorithm checks the sign of \( f(l_k) \) and \( f(u_k) \) for increasing
243     * values of k, where \( l_k = max(lower, initial - \delta_k) \),
244     * \( u_k = min(upper, initial + \delta_k) \), using recurrence
245     * \( \delta_{k+1} = r \delta_k + q, \delta_0 = 0\) and starting search with \( k=1 \).
246     * The algorithm stops when one of the following happens: <ul>
247     * <li> at least one positive and one negative value have been found --  success!</li>
248     * <li> both endpoints have reached their respective limits -- NoBracketingException </li>
249     * <li> {@code maximumIterations} iterations elapse -- NoBracketingException </li></ul>
250     * <p>
251     * If different signs are found at first iteration ({@code k=1}), then the returned
252     * interval will be \( [a, b] = [l_1, u_1] \). If different signs are found at a later
253     * iteration {@code k>1}, then the returned interval will be either
254     * \( [a, b] = [l_{k+1}, l_{k}] \) or \( [a, b] = [u_{k}, u_{k+1}] \). A root solver called
255     * with these parameters will therefore start with the smallest bracketing interval known
256     * at this step.
257     * </p>
258     * <p>
259     * Interval expansion rate is tuned by changing the recurrence parameters {@code r} and
260     * {@code q}. When the multiplicative factor {@code r} is set to 1, the sequence is a
261     * simple arithmetic sequence with linear increase. When the multiplicative factor {@code r}
262     * is larger than 1, the sequence has an asymptotically exponential rate. Note than the
263     * additive parameter {@code q} should never be set to zero, otherwise the interval would
264     * degenerate to the single initial point for all values of {@code k}.
265     * </p>
266     * <p>
267     * As a rule of thumb, when the location of the root is expected to be approximately known
268     * within some error margin, {@code r} should be set to 1 and {@code q} should be set to the
269     * order of magnitude of the error margin. When the location of the root is really a wild guess,
270     * then {@code r} should be set to a value larger than 1 (typically 2 to double the interval
271     * length at each iteration) and {@code q} should be set according to half the initial
272     * search interval length.
273     * </p>
274     * <p>
275     * As an example, if we consider the trivial function {@code f(x) = 1 - x} and use
276     * {@code initial = 4}, {@code r = 1}, {@code q = 2}, the algorithm will compute
277     * {@code f(4-2) = f(2) = -1} and {@code f(4+2) = f(6) = -5} for {@code k = 1}, then
278     * {@code f(4-4) = f(0) = +1} and {@code f(4+4) = f(8) = -7} for {@code k = 2}. Then it will
279     * return the interval {@code [0, 2]} as the smallest one known to be bracketing the root.
280     * As shown by this example, the initial value (here {@code 4}) may lie outside of the returned
281     * bracketing interval.
282     * </p>
283     * @param function function to check
284     * @param initial Initial midpoint of interval being expanded to
285     * bracket a root.
286     * @param lowerBound Lower bound (a is never lower than this value).
287     * @param upperBound Upper bound (b never is greater than this
288     * value).
289     * @param q additive offset used to compute bounds sequence (must be strictly positive)
290     * @param r multiplicative factor used to compute bounds sequence
291     * @param maximumIterations Maximum number of iterations to perform
292     * @return a two element array holding the bracketing values.
293     * @exception NoBracketingException if function cannot be bracketed in the search interval
294     */
295    public static double[] bracket(final UnivariateFunction function, final double initial,
296                                   final double lowerBound, final double upperBound,
297                                   final double q, final double r, final int maximumIterations)
298        throws NoBracketingException {
299
300        if (function == null) {
301            throw new NullArgumentException(LocalizedFormats.FUNCTION);
302        }
303        if (q <= 0)  {
304            throw new NotStrictlyPositiveException(q);
305        }
306        if (maximumIterations <= 0)  {
307            throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations);
308        }
309        verifySequence(lowerBound, initial, upperBound);
310
311        // initialize the recurrence
312        double a     = initial;
313        double b     = initial;
314        double fa    = Double.NaN;
315        double fb    = Double.NaN;
316        double delta = 0;
317
318        for (int numIterations = 0;
319             (numIterations < maximumIterations) && (a > lowerBound || b < upperBound);
320             ++numIterations) {
321
322            final double previousA  = a;
323            final double previousFa = fa;
324            final double previousB  = b;
325            final double previousFb = fb;
326
327            delta = r * delta + q;
328            a     = FastMath.max(initial - delta, lowerBound);
329            b     = FastMath.min(initial + delta, upperBound);
330            fa    = function.value(a);
331            fb    = function.value(b);
332
333            if (numIterations == 0) {
334                // at first iteration, we don't have a previous interval
335                // we simply compare both sides of the initial interval
336                if (fa * fb <= 0) {
337                    // the first interval already brackets a root
338                    return new double[] { a, b };
339                }
340            } else {
341                // we have a previous interval with constant sign and expand it,
342                // we expect sign changes to occur at boundaries
343                if (fa * previousFa <= 0) {
344                    // sign change detected at near lower bound
345                    return new double[] { a, previousA };
346                } else if (fb * previousFb <= 0) {
347                    // sign change detected at near upper bound
348                    return new double[] { previousB, b };
349                }
350            }
351
352        }
353
354        // no bracketing found
355        throw new NoBracketingException(a, b, fa, fb);
356
357    }
358
359    /**
360     * Compute the midpoint of two values.
361     *
362     * @param a first value.
363     * @param b second value.
364     * @return the midpoint.
365     */
366    public static double midpoint(double a, double b) {
367        return (a + b) * 0.5;
368    }
369
370    /**
371     * Check whether the interval bounds bracket a root. That is, if the
372     * values at the endpoints are not equal to zero, then the function takes
373     * opposite signs at the endpoints.
374     *
375     * @param function Function.
376     * @param lower Lower endpoint.
377     * @param upper Upper endpoint.
378     * @return {@code true} if the function values have opposite signs at the
379     * given points.
380     * @throws NullArgumentException if {@code function} is {@code null}.
381     */
382    public static boolean isBracketing(UnivariateFunction function,
383                                       final double lower,
384                                       final double upper)
385        throws NullArgumentException {
386        if (function == null) {
387            throw new NullArgumentException(LocalizedFormats.FUNCTION);
388        }
389        final double fLo = function.value(lower);
390        final double fHi = function.value(upper);
391        return (fLo >= 0 && fHi <= 0) || (fLo <= 0 && fHi >= 0);
392    }
393
394    /**
395     * Check whether the arguments form a (strictly) increasing sequence.
396     *
397     * @param start First number.
398     * @param mid Second number.
399     * @param end Third number.
400     * @return {@code true} if the arguments form an increasing sequence.
401     */
402    public static boolean isSequence(final double start,
403                                     final double mid,
404                                     final double end) {
405        return (start < mid) && (mid < end);
406    }
407
408    /**
409     * Check that the endpoints specify an interval.
410     *
411     * @param lower Lower endpoint.
412     * @param upper Upper endpoint.
413     * @throws NumberIsTooLargeException if {@code lower >= upper}.
414     */
415    public static void verifyInterval(final double lower,
416                                      final double upper)
417        throws NumberIsTooLargeException {
418        if (lower >= upper) {
419            throw new NumberIsTooLargeException(LocalizedFormats.ENDPOINTS_NOT_AN_INTERVAL,
420                                                lower, upper, false);
421        }
422    }
423
424    /**
425     * Check that {@code lower < initial < upper}.
426     *
427     * @param lower Lower endpoint.
428     * @param initial Initial value.
429     * @param upper Upper endpoint.
430     * @throws NumberIsTooLargeException if {@code lower >= initial} or
431     * {@code initial >= upper}.
432     */
433    public static void verifySequence(final double lower,
434                                      final double initial,
435                                      final double upper)
436        throws NumberIsTooLargeException {
437        verifyInterval(lower, initial);
438        verifyInterval(initial, upper);
439    }
440
441    /**
442     * Check that the endpoints specify an interval and the end points
443     * bracket a root.
444     *
445     * @param function Function.
446     * @param lower Lower endpoint.
447     * @param upper Upper endpoint.
448     * @throws NoBracketingException if the function has the same sign at the
449     * endpoints.
450     * @throws NullArgumentException if {@code function} is {@code null}.
451     */
452    public static void verifyBracketing(UnivariateFunction function,
453                                        final double lower,
454                                        final double upper)
455        throws NullArgumentException,
456               NoBracketingException {
457        if (function == null) {
458            throw new NullArgumentException(LocalizedFormats.FUNCTION);
459        }
460        verifyInterval(lower, upper);
461        if (!isBracketing(function, lower, upper)) {
462            throw new NoBracketingException(lower, upper,
463                                            function.value(lower),
464                                            function.value(upper));
465        }
466    }
467}