BracketFinder.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.inference;

  18. import java.util.function.DoubleUnaryOperator;

  19. /**
  20.  * Provide an interval that brackets a local minimum of a function.
  21.  * This code is based on a Python implementation (from <em>SciPy</em>,
  22.  * module {@code optimize.py} v0.5).
  23.  *
  24.  * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
  25.  * and modified to: remove support for bracketing a maximum; support bounds
  26.  * on the bracket; correct the sign of the denominator when the magnitude is small;
  27.  * and return true/false if there is a minimum strictly inside the bounds.
  28.  *
  29.  * @since 1.1
  30.  */
  31. class BracketFinder {
  32.     /** Tolerance to avoid division by zero. */
  33.     private static final double EPS_MIN = 1e-21;
  34.     /** Golden section. */
  35.     private static final double GOLD = 1.6180339887498948482;
  36.     /** Factor for expanding the interval. */
  37.     private final double growLimit;
  38.     /**  Number of allowed function evaluations. */
  39.     private final int maxEvaluations;
  40.     /** Number of function evaluations performed in the last search. */
  41.     private int evaluations;
  42.     /** Lower bound of the bracket. */
  43.     private double lo;
  44.     /** Higher bound of the bracket. */
  45.     private double hi;
  46.     /** Point inside the bracket. */
  47.     private double mid;
  48.     /** Function value at {@link #lo}. */
  49.     private double fLo;
  50.     /** Function value at {@link #hi}. */
  51.     private double fHi;
  52.     /** Function value at {@link #mid}. */
  53.     private double fMid;

  54.     /**
  55.      * Constructor with default values {@code 100, 100000} (see the
  56.      * {@link #BracketFinder(double,int) other constructor}).
  57.      */
  58.     BracketFinder() {
  59.         this(100, 100000);
  60.     }

  61.     /**
  62.      * Create a bracketing interval finder.
  63.      *
  64.      * @param growLimit Expanding factor.
  65.      * @param maxEvaluations Maximum number of evaluations allowed for finding
  66.      * a bracketing interval.
  67.      * @throws IllegalArgumentException if the {@code growLimit} or {@code maxEvalutations}
  68.      * are not strictly positive.
  69.      */
  70.     BracketFinder(double growLimit, int maxEvaluations) {
  71.         Arguments.checkStrictlyPositive(growLimit);
  72.         Arguments.checkStrictlyPositive(maxEvaluations);
  73.         this.growLimit = growLimit;
  74.         this.maxEvaluations = maxEvaluations;
  75.     }

  76.     /**
  77.      * Search downhill from the initial points to obtain new points that bracket a local
  78.      * minimum of the function. Note that the initial points do not have to bracket a minimum.
  79.      * An exception is raised if a minimum cannot be found within the configured number
  80.      * of function evaluations.
  81.      *
  82.      * <p>The bracket is limited to the provided bounds if they create a positive interval
  83.      * {@code min < max}. It is possible that the middle of the bracket is at the bounds as
  84.      * the final bracket is {@code f(mid) <= min(f(lo), f(hi))} and {@code lo <= mid <= hi}.
  85.      *
  86.      * <p>No exception is raised if the initial points are not within the bounds; the points
  87.      * are updated to be within the bounds.
  88.      *
  89.      * <p>No exception is raised if the initial points are equal; the bracket will be returned
  90.      * as a single point {@code lo == mid == hi}.
  91.      *
  92.      * @param func Function whose optimum should be bracketed.
  93.      * @param a Initial point.
  94.      * @param b Initial point.
  95.      * @param min Minimum bound of the bracket (inclusive).
  96.      * @param max Maximum bound of the bracket (inclusive).
  97.      * @return true if the mid-point is strictly within the final bracket {@code [lo, hi]};
  98.      * false if there is no local minima.
  99.      * @throws IllegalStateException if the maximum number of evaluations is exceeded.
  100.      */
  101.     boolean search(DoubleUnaryOperator func,
  102.                    double a, double b,
  103.                    double min, double max) {
  104.         evaluations = 0;

  105.         // Limit the range of x
  106.         final DoubleUnaryOperator range;
  107.         if (min < max) {
  108.             // Limit: min <= x <= max
  109.             range = x -> {
  110.                 if (x > min) {
  111.                     return x < max ? x : max;
  112.                 }
  113.                 return min;
  114.             };
  115.         } else {
  116.             range = DoubleUnaryOperator.identity();
  117.         }

  118.         double xA = range.applyAsDouble(a);
  119.         double xB = range.applyAsDouble(b);
  120.         double fA = value(func, xA);
  121.         double fB = value(func, xB);
  122.         // Ensure fB <= fA
  123.         if (fA < fB) {
  124.             double tmp = xA;
  125.             xA = xB;
  126.             xB = tmp;
  127.             tmp = fA;
  128.             fA = fB;
  129.             fB = tmp;
  130.         }

  131.         double xC = range.applyAsDouble(xB + GOLD * (xB - xA));
  132.         double fC = value(func, xC);

  133.         // Note: When a [min, max] interval is provided and there is no minima then this
  134.         // loop will terminate when B == C and both are at the min/max bound.
  135.         while (fC < fB) {
  136.             final double tmp1 = (xB - xA) * (fB - fC);
  137.             final double tmp2 = (xB - xC) * (fB - fA);

  138.             final double val = tmp2 - tmp1;
  139.             // limit magnitude of val to a small value
  140.             final double denom = 2 * Math.copySign(Math.max(Math.abs(val), EPS_MIN), val);

  141.             double w = range.applyAsDouble(xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom);
  142.             final double wLim = range.applyAsDouble(xB + growLimit * (xC - xB));

  143.             double fW;
  144.             if ((w - xC) * (xB - w) > 0) {
  145.                 // xB < w < xC
  146.                 fW = value(func, w);
  147.                 if (fW < fC) {
  148.                     // minimum in [xB, xC]
  149.                     xA = xB;
  150.                     xB = w;
  151.                     fA = fB;
  152.                     fB = fW;
  153.                     break;
  154.                 } else if (fW > fB) {
  155.                     // minimum in [xA, w]
  156.                     xC = w;
  157.                     fC = fW;
  158.                     break;
  159.                 }
  160.                 // continue downhill
  161.                 w = range.applyAsDouble(xC + GOLD * (xC - xB));
  162.                 fW = value(func, w);
  163.             } else if ((w - wLim) * (xC - w) > 0) {
  164.                 // xC < w < limit
  165.                 fW = value(func, w);
  166.                 if (fW < fC) {
  167.                     // continue downhill
  168.                     xB = xC;
  169.                     xC = w;
  170.                     w = range.applyAsDouble(xC + GOLD * (xC - xB));
  171.                     fB = fC;
  172.                     fC = fW;
  173.                     fW = value(func, w);
  174.                 }
  175.             } else if ((w - wLim) * (wLim - xC) >= 0) {
  176.                 // xC <= limit <= w
  177.                 w = wLim;
  178.                 fW = value(func, w);
  179.             } else {
  180.                 // possibly w == xC; reject w and take a default step
  181.                 w = range.applyAsDouble(xC + GOLD * (xC - xB));
  182.                 fW = value(func, w);
  183.             }

  184.             xA = xB;
  185.             fA = fB;
  186.             xB = xC;
  187.             fB = fC;
  188.             xC = w;
  189.             fC = fW;
  190.         }

  191.         mid = xB;
  192.         fMid = fB;

  193.         // Store the bracket: lo <= mid <= hi
  194.         if (xC < xA) {
  195.             lo = xC;
  196.             fLo = fC;
  197.             hi = xA;
  198.             fHi = fA;
  199.         } else {
  200.             lo = xA;
  201.             fLo = fA;
  202.             hi = xC;
  203.             fHi = fC;
  204.         }

  205.         return lo < mid && mid < hi;
  206.     }

  207.     /**
  208.      * @return the number of evaluations.
  209.      */
  210.     int getEvaluations() {
  211.         return evaluations;
  212.     }

  213.     /**
  214.      * @return the lower bound of the bracket.
  215.      * @see #getFLo()
  216.      */
  217.     double getLo() {
  218.         return lo;
  219.     }

  220.     /**
  221.      * Get function value at {@link #getLo()}.
  222.      * @return function value at {@link #getLo()}
  223.      */
  224.     double getFLo() {
  225.         return fLo;
  226.     }

  227.     /**
  228.      * @return the higher bound of the bracket.
  229.      * @see #getFHi()
  230.      */
  231.     double getHi() {
  232.         return hi;
  233.     }

  234.     /**
  235.      * Get function value at {@link #getHi()}.
  236.      * @return function value at {@link #getHi()}
  237.      */
  238.     double getFHi() {
  239.         return fHi;
  240.     }

  241.     /**
  242.      * @return a point in the middle of the bracket.
  243.      * @see #getFMid()
  244.      */
  245.     double getMid() {
  246.         return mid;
  247.     }

  248.     /**
  249.      * Get function value at {@link #getMid()}.
  250.      * @return function value at {@link #getMid()}
  251.      */
  252.     double getFMid() {
  253.         return fMid;
  254.     }

  255.     /**
  256.      * Get the value of the function.
  257.      *
  258.      * @param func Function.
  259.      * @param x Point.
  260.      * @return the value
  261.      * @throws IllegalStateException if the maximal number of evaluations is exceeded.
  262.      */
  263.     private double value(DoubleUnaryOperator func, double x) {
  264.         if (evaluations >= maxEvaluations) {
  265.             throw new IllegalStateException("Too many evaluations: " + evaluations);
  266.         }
  267.         evaluations++;
  268.         return func.applyAsDouble(x);
  269.     }
  270. }