UnivariateSolverUtils.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.analysis.solvers;

  18. import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
  19. import org.apache.commons.math4.legacy.exception.NoBracketingException;
  20. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  21. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  22. import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
  23. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  24. import org.apache.commons.math4.core.jdkmath.JdkMath;

  25. /**
  26.  * Utility routines for {@link UnivariateSolver} objects.
  27.  *
  28.  */
  29. public final class UnivariateSolverUtils {
  30.     /**
  31.      * Class contains only static methods.
  32.      */
  33.     private UnivariateSolverUtils() {}

  34.     /**
  35.      * Convenience method to find a zero of a univariate real function.  A default
  36.      * solver is used.
  37.      *
  38.      * @param function Function.
  39.      * @param x0 Lower bound for the interval.
  40.      * @param x1 Upper bound for the interval.
  41.      * @return a value where the function is zero.
  42.      * @throws NoBracketingException if the function has the same sign at the
  43.      * endpoints.
  44.      * @throws NullArgumentException if {@code function} is {@code null}.
  45.      */
  46.     public static double solve(UnivariateFunction function, double x0, double x1)
  47.         throws NullArgumentException,
  48.                NoBracketingException {
  49.         if (function == null) {
  50.             throw new NullArgumentException(LocalizedFormats.FUNCTION);
  51.         }
  52.         final UnivariateSolver solver = new BrentSolver();
  53.         return solver.solve(Integer.MAX_VALUE, function, x0, x1);
  54.     }

  55.     /**
  56.      * Convenience method to find a zero of a univariate real function.  A default
  57.      * solver is used.
  58.      *
  59.      * @param function Function.
  60.      * @param x0 Lower bound for the interval.
  61.      * @param x1 Upper bound for the interval.
  62.      * @param absoluteAccuracy Accuracy to be used by the solver.
  63.      * @return a value where the function is zero.
  64.      * @throws NoBracketingException if the function has the same sign at the
  65.      * endpoints.
  66.      * @throws NullArgumentException if {@code function} is {@code null}.
  67.      */
  68.     public static double solve(UnivariateFunction function,
  69.                                double x0, double x1,
  70.                                double absoluteAccuracy)
  71.         throws NullArgumentException,
  72.                NoBracketingException {
  73.         if (function == null) {
  74.             throw new NullArgumentException(LocalizedFormats.FUNCTION);
  75.         }
  76.         final UnivariateSolver solver = new BrentSolver(absoluteAccuracy);
  77.         return solver.solve(Integer.MAX_VALUE, function, x0, x1);
  78.     }

  79.     /**
  80.      * Force a root found by a non-bracketing solver to lie on a specified side,
  81.      * as if the solver were a bracketing one.
  82.      *
  83.      * @param maxEval maximal number of new evaluations of the function
  84.      * (evaluations already done for finding the root should have already been subtracted
  85.      * from this number)
  86.      * @param f function to solve
  87.      * @param bracketing bracketing solver to use for shifting the root
  88.      * @param baseRoot original root found by a previous non-bracketing solver
  89.      * @param min minimal bound of the search interval
  90.      * @param max maximal bound of the search interval
  91.      * @param allowedSolution the kind of solutions that the root-finding algorithm may
  92.      * accept as solutions.
  93.      * @return a root approximation, on the specified side of the exact root
  94.      * @throws NoBracketingException if the function has the same sign at the
  95.      * endpoints.
  96.      */
  97.     public static double forceSide(final int maxEval, final UnivariateFunction f,
  98.                                    final BracketedUnivariateSolver<UnivariateFunction> bracketing,
  99.                                    final double baseRoot, final double min, final double max,
  100.                                    final AllowedSolution allowedSolution)
  101.         throws NoBracketingException {

  102.         if (allowedSolution == AllowedSolution.ANY_SIDE) {
  103.             // no further bracketing required
  104.             return baseRoot;
  105.         }

  106.         // find a very small interval bracketing the root
  107.         final double step = JdkMath.max(bracketing.getAbsoluteAccuracy(),
  108.                                          JdkMath.abs(baseRoot * bracketing.getRelativeAccuracy()));
  109.         double xLo        = JdkMath.max(min, baseRoot - step);
  110.         double fLo        = f.value(xLo);
  111.         double xHi        = JdkMath.min(max, baseRoot + step);
  112.         double fHi        = f.value(xHi);
  113.         int remainingEval = maxEval - 2;
  114.         while (remainingEval > 0) {

  115.             if ((fLo >= 0 && fHi <= 0) || (fLo <= 0 && fHi >= 0)) {
  116.                 // compute the root on the selected side
  117.                 return bracketing.solve(remainingEval, f, xLo, xHi, baseRoot, allowedSolution);
  118.             }

  119.             // try increasing the interval
  120.             boolean changeLo = false;
  121.             boolean changeHi = false;
  122.             if (fLo < fHi) {
  123.                 // increasing function
  124.                 if (fLo >= 0) {
  125.                     changeLo = true;
  126.                 } else {
  127.                     changeHi = true;
  128.                 }
  129.             } else if (fLo > fHi) {
  130.                 // decreasing function
  131.                 if (fLo <= 0) {
  132.                     changeLo = true;
  133.                 } else {
  134.                     changeHi = true;
  135.                 }
  136.             } else {
  137.                 // unknown variation
  138.                 changeLo = true;
  139.                 changeHi = true;
  140.             }

  141.             // update the lower bound
  142.             if (changeLo) {
  143.                 xLo = JdkMath.max(min, xLo - step);
  144.                 fLo  = f.value(xLo);
  145.                 remainingEval--;
  146.             }

  147.             // update the higher bound
  148.             if (changeHi) {
  149.                 xHi = JdkMath.min(max, xHi + step);
  150.                 fHi  = f.value(xHi);
  151.                 remainingEval--;
  152.             }
  153.         }

  154.         throw new NoBracketingException(LocalizedFormats.FAILED_BRACKETING,
  155.                                         xLo, xHi, fLo, fHi,
  156.                                         maxEval - remainingEval, maxEval, baseRoot,
  157.                                         min, max);
  158.     }

  159.     /**
  160.      * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
  161.      * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
  162.      * with {@code q} and {@code r} set to 1.0 and {@code maximumIterations} set to {@code Integer.MAX_VALUE}.
  163.      * <p>
  164.      * <strong>Note: </strong> this method can take {@code Integer.MAX_VALUE}
  165.      * iterations to throw a {@code ConvergenceException.}  Unless you are
  166.      * confident that there is a root between {@code lowerBound} and
  167.      * {@code upperBound} near {@code initial}, it is better to use
  168.      * {@link #bracket(UnivariateFunction, double, double, double, double,double, int)
  169.      * bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)},
  170.      * explicitly specifying the maximum number of iterations.</p>
  171.      *
  172.      * @param function Function.
  173.      * @param initial Initial midpoint of interval being expanded to
  174.      * bracket a root.
  175.      * @param lowerBound Lower bound (a is never lower than this value)
  176.      * @param upperBound Upper bound (b never is greater than this
  177.      * value).
  178.      * @return a two-element array holding a and b.
  179.      * @throws NoBracketingException if a root cannot be bracketted.
  180.      * @throws NotStrictlyPositiveException if {@code maximumIterations <= 0}.
  181.      * @throws NullArgumentException if {@code function} is {@code null}.
  182.      */
  183.     public static double[] bracket(UnivariateFunction function,
  184.                                    double initial,
  185.                                    double lowerBound, double upperBound)
  186.         throws NullArgumentException,
  187.                NotStrictlyPositiveException,
  188.                NoBracketingException {
  189.         return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, Integer.MAX_VALUE);
  190.     }

  191.      /**
  192.      * This method simply calls {@link #bracket(UnivariateFunction, double, double, double,
  193.      * double, double, int) bracket(function, initial, lowerBound, upperBound, q, r, maximumIterations)}
  194.      * with {@code q} and {@code r} set to 1.0.
  195.      * @param function Function.
  196.      * @param initial Initial midpoint of interval being expanded to
  197.      * bracket a root.
  198.      * @param lowerBound Lower bound (a is never lower than this value).
  199.      * @param upperBound Upper bound (b never is greater than this
  200.      * value).
  201.      * @param maximumIterations Maximum number of iterations to perform
  202.      * @return a two element array holding a and b.
  203.      * @throws NoBracketingException if the algorithm fails to find a and b
  204.      * satisfying the desired conditions.
  205.      * @throws NotStrictlyPositiveException if {@code maximumIterations <= 0}.
  206.      * @throws NullArgumentException if {@code function} is {@code null}.
  207.      */
  208.     public static double[] bracket(UnivariateFunction function,
  209.                                    double initial,
  210.                                    double lowerBound, double upperBound,
  211.                                    int maximumIterations)
  212.         throws NullArgumentException,
  213.                NotStrictlyPositiveException,
  214.                NoBracketingException {
  215.         return bracket(function, initial, lowerBound, upperBound, 1.0, 1.0, maximumIterations);
  216.     }

  217.     /**
  218.      * This method attempts to find two values a and b satisfying <ul>
  219.      * <li> {@code lowerBound <= a < initial < b <= upperBound} </li>
  220.      * <li> {@code f(a) * f(b) <= 0} </li>
  221.      * </ul>
  222.      * If {@code f} is continuous on {@code [a,b]}, this means that {@code a}
  223.      * and {@code b} bracket a root of {@code f}.
  224.      * <p>
  225.      * The algorithm checks the sign of \( f(l_k) \) and \( f(u_k) \) for increasing
  226.      * values of k, where \( l_k = max(lower, initial - \delta_k) \),
  227.      * \( u_k = min(upper, initial + \delta_k) \), using recurrence
  228.      * \( \delta_{k+1} = r \delta_k + q, \delta_0 = 0\) and starting search with \( k=1 \).
  229.      * The algorithm stops when one of the following happens: <ul>
  230.      * <li> at least one positive and one negative value have been found --  success!</li>
  231.      * <li> both endpoints have reached their respective limits -- NoBracketingException </li>
  232.      * <li> {@code maximumIterations} iterations elapse -- NoBracketingException </li></ul>
  233.      * <p>
  234.      * If different signs are found at first iteration ({@code k=1}), then the returned
  235.      * interval will be \( [a, b] = [l_1, u_1] \). If different signs are found at a later
  236.      * iteration {@code k>1}, then the returned interval will be either
  237.      * \( [a, b] = [l_{k+1}, l_{k}] \) or \( [a, b] = [u_{k}, u_{k+1}] \). A root solver called
  238.      * with these parameters will therefore start with the smallest bracketing interval known
  239.      * at this step.
  240.      * </p>
  241.      * <p>
  242.      * Interval expansion rate is tuned by changing the recurrence parameters {@code r} and
  243.      * {@code q}. When the multiplicative factor {@code r} is set to 1, the sequence is a
  244.      * simple arithmetic sequence with linear increase. When the multiplicative factor {@code r}
  245.      * is larger than 1, the sequence has an asymptotically exponential rate. Note than the
  246.      * additive parameter {@code q} should never be set to zero, otherwise the interval would
  247.      * degenerate to the single initial point for all values of {@code k}.
  248.      * </p>
  249.      * <p>
  250.      * As a rule of thumb, when the location of the root is expected to be approximately known
  251.      * within some error margin, {@code r} should be set to 1 and {@code q} should be set to the
  252.      * order of magnitude of the error margin. When the location of the root is really a wild guess,
  253.      * then {@code r} should be set to a value larger than 1 (typically 2 to double the interval
  254.      * length at each iteration) and {@code q} should be set according to half the initial
  255.      * search interval length.
  256.      * </p>
  257.      * <p>
  258.      * As an example, if we consider the trivial function {@code f(x) = 1 - x} and use
  259.      * {@code initial = 4}, {@code r = 1}, {@code q = 2}, the algorithm will compute
  260.      * {@code f(4-2) = f(2) = -1} and {@code f(4+2) = f(6) = -5} for {@code k = 1}, then
  261.      * {@code f(4-4) = f(0) = +1} and {@code f(4+4) = f(8) = -7} for {@code k = 2}. Then it will
  262.      * return the interval {@code [0, 2]} as the smallest one known to be bracketing the root.
  263.      * As shown by this example, the initial value (here {@code 4}) may lie outside of the returned
  264.      * bracketing interval.
  265.      * </p>
  266.      * @param function function to check
  267.      * @param initial Initial midpoint of interval being expanded to
  268.      * bracket a root.
  269.      * @param lowerBound Lower bound (a is never lower than this value).
  270.      * @param upperBound Upper bound (b never is greater than this
  271.      * value).
  272.      * @param q additive offset used to compute bounds sequence (must be strictly positive)
  273.      * @param r multiplicative factor used to compute bounds sequence
  274.      * @param maximumIterations Maximum number of iterations to perform
  275.      * @return a two element array holding the bracketing values.
  276.      * @exception NoBracketingException if function cannot be bracketed in the search interval
  277.      */
  278.     public static double[] bracket(final UnivariateFunction function, final double initial,
  279.                                    final double lowerBound, final double upperBound,
  280.                                    final double q, final double r, final int maximumIterations)
  281.         throws NoBracketingException {

  282.         if (function == null) {
  283.             throw new NullArgumentException(LocalizedFormats.FUNCTION);
  284.         }
  285.         if (q <= 0)  {
  286.             throw new NotStrictlyPositiveException(q);
  287.         }
  288.         if (maximumIterations <= 0)  {
  289.             throw new NotStrictlyPositiveException(LocalizedFormats.INVALID_MAX_ITERATIONS, maximumIterations);
  290.         }
  291.         verifySequence(lowerBound, initial, upperBound);

  292.         // initialize the recurrence
  293.         double a     = initial;
  294.         double b     = initial;
  295.         double fa    = Double.NaN;
  296.         double fb    = Double.NaN;
  297.         double delta = 0;

  298.         for (int numIterations = 0;
  299.              numIterations < maximumIterations && (a > lowerBound || b < upperBound);
  300.              ++numIterations) {

  301.             final double previousA  = a;
  302.             final double previousFa = fa;
  303.             final double previousB  = b;
  304.             final double previousFb = fb;

  305.             delta = r * delta + q;
  306.             a     = JdkMath.max(initial - delta, lowerBound);
  307.             b     = JdkMath.min(initial + delta, upperBound);
  308.             fa    = function.value(a);
  309.             fb    = function.value(b);

  310.             if (numIterations == 0) {
  311.                 // at first iteration, we don't have a previous interval
  312.                 // we simply compare both sides of the initial interval
  313.                 if (fa * fb <= 0) {
  314.                     // the first interval already brackets a root
  315.                     return new double[] { a, b };
  316.                 }
  317.             } else {
  318.                 // we have a previous interval with constant sign and expand it,
  319.                 // we expect sign changes to occur at boundaries
  320.                 if (fa * previousFa <= 0) {
  321.                     // sign change detected at near lower bound
  322.                     return new double[] { a, previousA };
  323.                 } else if (fb * previousFb <= 0) {
  324.                     // sign change detected at near upper bound
  325.                     return new double[] { previousB, b };
  326.                 }
  327.             }
  328.         }

  329.         // no bracketing found
  330.         throw new NoBracketingException(a, b, fa, fb);
  331.     }

  332.     /**
  333.      * Compute the midpoint of two values.
  334.      *
  335.      * @param a first value.
  336.      * @param b second value.
  337.      * @return the midpoint.
  338.      */
  339.     public static double midpoint(double a, double b) {
  340.         return (a + b) * 0.5;
  341.     }

  342.     /**
  343.      * Check whether the interval bounds bracket a root. That is, if the
  344.      * values at the endpoints are not equal to zero, then the function takes
  345.      * opposite signs at the endpoints.
  346.      *
  347.      * @param function Function.
  348.      * @param lower Lower endpoint.
  349.      * @param upper Upper endpoint.
  350.      * @return {@code true} if the function values have opposite signs at the
  351.      * given points.
  352.      * @throws NullArgumentException if {@code function} is {@code null}.
  353.      */
  354.     public static boolean isBracketing(UnivariateFunction function,
  355.                                        final double lower,
  356.                                        final double upper)
  357.         throws NullArgumentException {
  358.         if (function == null) {
  359.             throw new NullArgumentException(LocalizedFormats.FUNCTION);
  360.         }
  361.         final double fLo = function.value(lower);
  362.         final double fHi = function.value(upper);
  363.         return (fLo >= 0 && fHi <= 0) || (fLo <= 0 && fHi >= 0);
  364.     }

  365.     /**
  366.      * Check whether the arguments form a (strictly) increasing sequence.
  367.      *
  368.      * @param start First number.
  369.      * @param mid Second number.
  370.      * @param end Third number.
  371.      * @return {@code true} if the arguments form an increasing sequence.
  372.      */
  373.     public static boolean isSequence(final double start,
  374.                                      final double mid,
  375.                                      final double end) {
  376.         return start < mid && mid < end;
  377.     }

  378.     /**
  379.      * Check that the endpoints specify an interval.
  380.      *
  381.      * @param lower Lower endpoint.
  382.      * @param upper Upper endpoint.
  383.      * @throws NumberIsTooLargeException if {@code lower >= upper}.
  384.      */
  385.     public static void verifyInterval(final double lower,
  386.                                       final double upper)
  387.         throws NumberIsTooLargeException {
  388.         if (lower >= upper) {
  389.             throw new NumberIsTooLargeException(LocalizedFormats.ENDPOINTS_NOT_AN_INTERVAL,
  390.                                                 lower, upper, false);
  391.         }
  392.     }

  393.     /**
  394.      * Check that {@code lower < initial < upper}.
  395.      *
  396.      * @param lower Lower endpoint.
  397.      * @param initial Initial value.
  398.      * @param upper Upper endpoint.
  399.      * @throws NumberIsTooLargeException if {@code lower >= initial} or
  400.      * {@code initial >= upper}.
  401.      */
  402.     public static void verifySequence(final double lower,
  403.                                       final double initial,
  404.                                       final double upper)
  405.         throws NumberIsTooLargeException {
  406.         verifyInterval(lower, initial);
  407.         verifyInterval(initial, upper);
  408.     }

  409.     /**
  410.      * Check that the endpoints specify an interval and the end points
  411.      * bracket a root.
  412.      *
  413.      * @param function Function.
  414.      * @param lower Lower endpoint.
  415.      * @param upper Upper endpoint.
  416.      * @throws NoBracketingException if the function has the same sign at the
  417.      * endpoints.
  418.      * @throws NullArgumentException if {@code function} is {@code null}.
  419.      */
  420.     public static void verifyBracketing(UnivariateFunction function,
  421.                                         final double lower,
  422.                                         final double upper)
  423.         throws NullArgumentException,
  424.                NoBracketingException {
  425.         if (function == null) {
  426.             throw new NullArgumentException(LocalizedFormats.FUNCTION);
  427.         }
  428.         verifyInterval(lower, upper);
  429.         if (!isBracketing(function, lower, upper)) {
  430.             throw new NoBracketingException(lower, upper,
  431.                                             function.value(lower),
  432.                                             function.value(upper));
  433.         }
  434.     }
  435. }