BrentOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.inference;

  18. import java.util.function.DoubleUnaryOperator;
  19. import org.apache.commons.numbers.core.Precision;

  20. /**
  21.  * For a function defined on some interval {@code (lo, hi)}, this class
  22.  * finds an approximation {@code x} to the point at which the function
  23.  * attains its minimum.
  24.  * It implements Richard Brent's algorithm (from his book "Algorithms for
  25.  * Minimization without Derivatives", p. 79) for finding minima of real
  26.  * univariate functions.
  27.  *
  28.  * <P>This code is an adaptation, partly based on the Python code from SciPy
  29.  * (module "optimize.py" v0.5); the original algorithm is also modified:
  30.  * <ul>
  31.  *  <li>to use an initial guess provided by the user,</li>
  32.  *  <li>to ensure that the best point encountered is the one returned.</li>
  33.  * </ul>
  34.  *
  35.  * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
  36.  * and simplified to remove support for the UnivariateOptimizer interface.
  37.  * This removed the options: to find the maximum; use a custom convergence checker
  38.  * on the function value; and remove the maximum function evaluation count.
  39.  * The class now implements a single optimize method within the provided bracket
  40.  * from the given start position (with value).
  41.  *
  42.  * @since 1.1
  43.  */
  44. final class BrentOptimizer {
  45.     /** Golden section. (3 - sqrt(5)) / 2. */
  46.     private static final double GOLDEN_SECTION = 0.3819660112501051;
  47.     /** Minimum relative tolerance. 2 * eps = 2^-51. */
  48.     private static final double MIN_RELATIVE_TOLERANCE = 0x1.0p-51;

  49.     /** Relative threshold. */
  50.     private final double relativeThreshold;
  51.     /** Absolute threshold. */
  52.     private final double absoluteThreshold;
  53.     /** The number of function evaluations from the most recent call to optimize. */
  54.     private int evaluations;

  55.     /**
  56.      * This class holds a point and the value of an objective function at this
  57.      * point. This is a simple immutable container.
  58.      *
  59.      * @since 1.1
  60.      */
  61.     static final class PointValuePair {
  62.         /** Point. */
  63.         private final double point;
  64.         /** Value of the objective function at the point. */
  65.         private final double value;

  66.         /**
  67.          * @param point Point.
  68.          * @param value Value of an objective function at the point.
  69.          */
  70.         private PointValuePair(double point, double value) {
  71.             this.point = point;
  72.             this.value = value;
  73.         }

  74.         /**
  75.          * Create a point/objective function value pair.
  76.          *
  77.          * @param point Point.
  78.          * @param value Value of an objective function at the point.
  79.          * @return the pair
  80.          */
  81.         static PointValuePair of(double point, double value) {
  82.             return new PointValuePair(point, value);
  83.         }

  84.         /**
  85.          * Get the point.
  86.          *
  87.          * @return the point.
  88.          */
  89.         double getPoint() {
  90.             return point;
  91.         }

  92.         /**
  93.          * Get the value of the objective function.
  94.          *
  95.          * @return the stored value of the objective function.
  96.          */
  97.         double getValue() {
  98.             return value;
  99.         }
  100.     }

  101.     /**
  102.      * The arguments are used to implement the original stopping criterion
  103.      * of Brent's algorithm.
  104.      * {@code abs} and {@code rel} define a tolerance
  105.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  106.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  107.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  108.      * be positive.
  109.      *
  110.      * @param rel Relative threshold.
  111.      * @param abs Absolute threshold.
  112.      * @throws IllegalArgumentException if {@code abs <= 0}; or if {@code rel < 2 * Math.ulp(1.0)}
  113.      */
  114.     BrentOptimizer(double rel, double abs) {
  115.         if (rel >= MIN_RELATIVE_TOLERANCE) {
  116.             relativeThreshold = rel;
  117.             absoluteThreshold = Arguments.checkStrictlyPositive(abs);
  118.         } else {
  119.             // relative too small, or NaN
  120.             throw new InferenceException(InferenceException.X_LT_Y, rel, MIN_RELATIVE_TOLERANCE);
  121.         }
  122.     }

  123.     /**
  124.      * Gets the number of function evaluations from the most recent call to
  125.      * {@link #optimize(DoubleUnaryOperator, double, double, double, double) optimize}.
  126.      *
  127.      * @return the function evaluations
  128.      */
  129.     int getEvaluations() {
  130.         return evaluations;
  131.     }

  132.     /**
  133.      * Search for the minimum inside the provided interval. The bracket must satisfy
  134.      * the equalities {@code lo < mid < hi} or {@code hi < mid < lo}.
  135.      *
  136.      * <p>Note: This function accepts the initial guess and the function value at that point.
  137.      * This is done for convenience as this internal class is used where the caller already
  138.      * knows the function value.
  139.      *
  140.      * @param func Function to solve.
  141.      * @param lo Lower bound of the search interval.
  142.      * @param hi Higher bound of the search interval.
  143.      * @param mid Start point.
  144.      * @param fMid Function value at the start point.
  145.      * @return the value where the function is minimum.
  146.      * @throws IllegalArgumentException if start point is not within the search interval
  147.      * @throws IllegalStateException if the maximum number of iterations is exceeded
  148.      */
  149.     PointValuePair optimize(DoubleUnaryOperator func,
  150.                             double lo, double hi,
  151.                             double mid, double fMid) {
  152.         double a;
  153.         double b;
  154.         if (lo < hi) {
  155.             a = lo;
  156.             b = hi;
  157.         } else {
  158.             a = hi;
  159.             b = lo;
  160.         }
  161.         if (!(a < mid && mid < b)) {
  162.             throw new InferenceException("Invalid bounds: (%s, %s) with start %s", a, b, mid);
  163.         }
  164.         double x = mid;
  165.         double v = x;
  166.         double w = x;
  167.         double d = 0;
  168.         double e = 0;
  169.         double fx = fMid;
  170.         double fv = fx;
  171.         double fw = fx;

  172.         // Best point encountered so far (which is the initial guess).
  173.         double bestX = x;
  174.         double bestFx = fx;

  175.         // No test for iteration count.
  176.         // Note that the termination criterion is based purely on the size of the current
  177.         // bracket and the current point x. If the function evaluates NaN then golden
  178.         // section steps are taken.
  179.         evaluations = 0;
  180.         for (;;) {
  181.             final double m = 0.5 * (a + b);
  182.             final double tol1 = relativeThreshold * Math.abs(x) + absoluteThreshold;
  183.             final double tol2 = 2 * tol1;

  184.             // Default termination (Brent's criterion).
  185.             if (Math.abs(x - m) <= tol2 - 0.5 * (b - a)) {
  186.                 return PointValuePair.of(bestX, bestFx);
  187.             }

  188.             if (Math.abs(e) > tol1) {
  189.                 // Fit parabola.
  190.                 double r = (x - w) * (fx - fv);
  191.                 double q = (x - v) * (fx - fw);
  192.                 double p = (x - v) * q - (x - w) * r;
  193.                 q = 2 * (q - r);

  194.                 if (q > 0) {
  195.                     p = -p;
  196.                 } else {
  197.                     q = -q;
  198.                 }

  199.                 r = e;
  200.                 e = d;

  201.                 if (p > q * (a - x) &&
  202.                     p < q * (b - x) &&
  203.                     Math.abs(p) < Math.abs(0.5 * q * r)) {
  204.                     // Parabolic interpolation step.
  205.                     d = p / q;
  206.                     final double u = x + d;

  207.                     // f must not be evaluated too close to a or b.
  208.                     if (u - a < tol2 || b - u < tol2) {
  209.                         if (x <= m) {
  210.                             d = tol1;
  211.                         } else {
  212.                             d = -tol1;
  213.                         }
  214.                     }
  215.                 } else {
  216.                     // Golden section step.
  217.                     if (x < m) {
  218.                         e = b - x;
  219.                     } else {
  220.                         e = a - x;
  221.                     }
  222.                     d = GOLDEN_SECTION * e;
  223.                 }
  224.             } else {
  225.                 // Golden section step.
  226.                 if (x < m) {
  227.                     e = b - x;
  228.                 } else {
  229.                     e = a - x;
  230.                 }
  231.                 d = GOLDEN_SECTION * e;
  232.             }

  233.             // Update by at least "tol1".
  234.             // Here d is never NaN so the evaluation point u is always finite.
  235.             final double u;
  236.             if (Math.abs(d) < tol1) {
  237.                 if (d >= 0) {
  238.                     u = x + tol1;
  239.                 } else {
  240.                     u = x - tol1;
  241.                 }
  242.             } else {
  243.                 u = x + d;
  244.             }

  245.             evaluations++;
  246.             final double fu = func.applyAsDouble(u);

  247.             // Maintain the best encountered result
  248.             if (fu < bestFx) {
  249.                 bestX = u;
  250.                 bestFx = fu;
  251.             }

  252.             // Note:
  253.             // Here the use of a convergence checker on f(x) previous vs current has been removed.
  254.             // Typically when the checker requires a very small relative difference
  255.             // the optimizer will stop before, or soon after, on Brent's criterion when that is
  256.             // configured with the smallest recommended convergence criteria.

  257.             // Update a, b, v, w and x.
  258.             if (fu <= fx) {
  259.                 if (u < x) {
  260.                     b = x;
  261.                 } else {
  262.                     a = x;
  263.                 }
  264.                 v = w;
  265.                 fv = fw;
  266.                 w = x;
  267.                 fw = fx;
  268.                 x = u;
  269.                 fx = fu;
  270.             } else {
  271.                 if (u < x) {
  272.                     a = u;
  273.                 } else {
  274.                     b = u;
  275.                 }
  276.                 if (fu <= fw ||
  277.                     Precision.equals(w, x)) {
  278.                     v = w;
  279.                     fv = fw;
  280.                     w = u;
  281.                     fw = fu;
  282.                 } else if (fu <= fv ||
  283.                            Precision.equals(v, x) ||
  284.                            Precision.equals(v, w)) {
  285.                     v = u;
  286.                     fv = fu;
  287.                 }
  288.             }
  289.         }
  290.     }
  291. }