BrentOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.optim.univariate;

  18. import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
  19. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  20. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  21. import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
  22. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
  23. import org.apache.commons.math4.core.jdkmath.JdkMath;
  24. import org.apache.commons.numbers.core.Precision;

  25. /**
  26.  * For a function defined on some interval {@code (lo, hi)}, this class
  27.  * finds an approximation {@code x} to the point at which the function
  28.  * attains its minimum.
  29.  * It implements Richard Brent's algorithm (from his book "Algorithms for
  30.  * Minimization without Derivatives", p. 79) for finding minima of real
  31.  * univariate functions.
  32.  * <br>
  33.  * This code is an adaptation, partly based on the Python code from SciPy
  34.  * (module "optimize.py" v0.5); the original algorithm is also modified
  35.  * <ul>
  36.  *  <li>to use an initial guess provided by the user,</li>
  37.  *  <li>to ensure that the best point encountered is the one returned.</li>
  38.  * </ul>
  39.  *
  40.  * @since 2.0
  41.  */
  42. public class BrentOptimizer extends UnivariateOptimizer {
  43.     /**
  44.      * Golden section.
  45.      */
  46.     private static final double GOLDEN_SECTION = 0.5 * (3 - JdkMath.sqrt(5));
  47.     /**
  48.      * Minimum relative tolerance.
  49.      */
  50.     private static final double MIN_RELATIVE_TOLERANCE = 2 * JdkMath.ulp(1d);
  51.     /**
  52.      * Relative threshold.
  53.      */
  54.     private final double relativeThreshold;
  55.     /**
  56.      * Absolute threshold.
  57.      */
  58.     private final double absoluteThreshold;

  59.     /**
  60.      * The arguments are used implement the original stopping criterion
  61.      * of Brent's algorithm.
  62.      * {@code abs} and {@code rel} define a tolerance
  63.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  64.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  65.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  66.      * be positive.
  67.      *
  68.      * @param rel Relative threshold.
  69.      * @param abs Absolute threshold.
  70.      * @param checker Additional, user-defined, convergence checking
  71.      * procedure.
  72.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  73.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  74.      */
  75.     public BrentOptimizer(double rel,
  76.                           double abs,
  77.                           ConvergenceChecker<UnivariatePointValuePair> checker) {
  78.         super(checker);

  79.         if (rel < MIN_RELATIVE_TOLERANCE) {
  80.             throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
  81.         }
  82.         if (abs <= 0) {
  83.             throw new NotStrictlyPositiveException(abs);
  84.         }

  85.         relativeThreshold = rel;
  86.         absoluteThreshold = abs;
  87.     }

  88.     /**
  89.      * The arguments are used for implementing the original stopping criterion
  90.      * of Brent's algorithm.
  91.      * {@code abs} and {@code rel} define a tolerance
  92.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  93.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  94.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  95.      * be positive.
  96.      *
  97.      * @param rel Relative threshold.
  98.      * @param abs Absolute threshold.
  99.      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
  100.      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
  101.      */
  102.     public BrentOptimizer(double rel,
  103.                           double abs) {
  104.         this(rel, abs, null);
  105.     }

  106.     /** {@inheritDoc} */
  107.     @Override
  108.     protected UnivariatePointValuePair doOptimize() {
  109.         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
  110.         final double lo = getMin();
  111.         final double mid = getStartValue();
  112.         final double hi = getMax();
  113.         final UnivariateFunction func = getObjectiveFunction();

  114.         // Optional additional convergence criteria.
  115.         final ConvergenceChecker<UnivariatePointValuePair> checker
  116.             = getConvergenceChecker();

  117.         double a;
  118.         double b;
  119.         if (lo < hi) {
  120.             a = lo;
  121.             b = hi;
  122.         } else {
  123.             a = hi;
  124.             b = lo;
  125.         }

  126.         double x = mid;
  127.         double v = x;
  128.         double w = x;
  129.         double d = 0;
  130.         double e = 0;
  131.         double fx = func.value(x);
  132.         if (!isMinim) {
  133.             fx = -fx;
  134.         }
  135.         double fv = fx;
  136.         double fw = fx;

  137.         UnivariatePointValuePair previous = null;
  138.         UnivariatePointValuePair current
  139.             = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
  140.         // Best point encountered so far (which is the initial guess).
  141.         UnivariatePointValuePair best = current;

  142.         while (true) {
  143.             final double m = 0.5 * (a + b);
  144.             final double tol1 = relativeThreshold * JdkMath.abs(x) + absoluteThreshold;
  145.             final double tol2 = 2 * tol1;

  146.             // Default stopping criterion.
  147.             final boolean stop = JdkMath.abs(x - m) <= tol2 - 0.5 * (b - a);
  148.             if (!stop) {
  149.                 double p = 0;
  150.                 double q = 0;
  151.                 double r = 0;
  152.                 double u = 0;

  153.                 if (JdkMath.abs(e) > tol1) { // Fit parabola.
  154.                     r = (x - w) * (fx - fv);
  155.                     q = (x - v) * (fx - fw);
  156.                     p = (x - v) * q - (x - w) * r;
  157.                     q = 2 * (q - r);

  158.                     if (q > 0) {
  159.                         p = -p;
  160.                     } else {
  161.                         q = -q;
  162.                     }

  163.                     r = e;
  164.                     e = d;

  165.                     if (p > q * (a - x) &&
  166.                         p < q * (b - x) &&
  167.                         JdkMath.abs(p) < JdkMath.abs(0.5 * q * r)) {
  168.                         // Parabolic interpolation step.
  169.                         d = p / q;
  170.                         u = x + d;

  171.                         // f must not be evaluated too close to a or b.
  172.                         if (u - a < tol2 || b - u < tol2) {
  173.                             if (x <= m) {
  174.                                 d = tol1;
  175.                             } else {
  176.                                 d = -tol1;
  177.                             }
  178.                         }
  179.                     } else {
  180.                         // Golden section step.
  181.                         if (x < m) {
  182.                             e = b - x;
  183.                         } else {
  184.                             e = a - x;
  185.                         }
  186.                         d = GOLDEN_SECTION * e;
  187.                     }
  188.                 } else {
  189.                     // Golden section step.
  190.                     if (x < m) {
  191.                         e = b - x;
  192.                     } else {
  193.                         e = a - x;
  194.                     }
  195.                     d = GOLDEN_SECTION * e;
  196.                 }

  197.                 // Update by at least "tol1".
  198.                 if (JdkMath.abs(d) < tol1) {
  199.                     if (d >= 0) {
  200.                         u = x + tol1;
  201.                     } else {
  202.                         u = x - tol1;
  203.                     }
  204.                 } else {
  205.                     u = x + d;
  206.                 }

  207.                 double fu = func.value(u);
  208.                 if (!isMinim) {
  209.                     fu = -fu;
  210.                 }

  211.                 // User-defined convergence checker.
  212.                 previous = current;
  213.                 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
  214.                 best = best(best,
  215.                             best(previous,
  216.                                  current,
  217.                                  isMinim),
  218.                             isMinim);

  219.                 if (checker != null && checker.converged(getIterations(), previous, current)) {
  220.                     return best;
  221.                 }

  222.                 // Update a, b, v, w and x.
  223.                 if (fu <= fx) {
  224.                     if (u < x) {
  225.                         b = x;
  226.                     } else {
  227.                         a = x;
  228.                     }
  229.                     v = w;
  230.                     fv = fw;
  231.                     w = x;
  232.                     fw = fx;
  233.                     x = u;
  234.                     fx = fu;
  235.                 } else {
  236.                     if (u < x) {
  237.                         a = u;
  238.                     } else {
  239.                         b = u;
  240.                     }
  241.                     if (fu <= fw ||
  242.                         Precision.equals(w, x)) {
  243.                         v = w;
  244.                         fv = fw;
  245.                         w = u;
  246.                         fw = fu;
  247.                     } else if (fu <= fv ||
  248.                                Precision.equals(v, x) ||
  249.                                Precision.equals(v, w)) {
  250.                         v = u;
  251.                         fv = fu;
  252.                     }
  253.                 }
  254.             } else { // Default termination (Brent's criterion).
  255.                 return best(best,
  256.                             best(previous,
  257.                                  current,
  258.                                  isMinim),
  259.                             isMinim);
  260.             }

  261.             incrementIterationCount();
  262.         }
  263.     }

  264.     /**
  265.      * Selects the best of two points.
  266.      *
  267.      * @param a Point and value.
  268.      * @param b Point and value.
  269.      * @param isMinim {@code true} if the selected point must be the one with
  270.      * the lowest value.
  271.      * @return the best point, or {@code null} if {@code a} and {@code b} are
  272.      * both {@code null}. When {@code a} and {@code b} have the same function
  273.      * value, {@code a} is returned.
  274.      */
  275.     private UnivariatePointValuePair best(UnivariatePointValuePair a,
  276.                                           UnivariatePointValuePair b,
  277.                                           boolean isMinim) {
  278.         if (a == null) {
  279.             return b;
  280.         }
  281.         if (b == null) {
  282.             return a;
  283.         }

  284.         if (isMinim) {
  285.             return a.getValue() <= b.getValue() ? a : b;
  286.         } else {
  287.             return a.getValue() >= b.getValue() ? a : b;
  288.         }
  289.     }
  290. }