BracketFinder.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.statistics.inference;
- import java.util.function.DoubleUnaryOperator;
- /**
- * Provide an interval that brackets a local minimum of a function.
- * This code is based on a Python implementation (from <em>SciPy</em>,
- * module {@code optimize.py} v0.5).
- *
- * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
- * and modified to: remove support for bracketing a maximum; support bounds
- * on the bracket; correct the sign of the denominator when the magnitude is small;
- * and return true/false if there is a minimum strictly inside the bounds.
- *
- * @since 1.1
- */
- class BracketFinder {
- /** Tolerance to avoid division by zero. */
- private static final double EPS_MIN = 1e-21;
- /** Golden section. */
- private static final double GOLD = 1.6180339887498948482;
- /** Factor for expanding the interval. */
- private final double growLimit;
- /** Number of allowed function evaluations. */
- private final int maxEvaluations;
- /** Number of function evaluations performed in the last search. */
- private int evaluations;
- /** Lower bound of the bracket. */
- private double lo;
- /** Higher bound of the bracket. */
- private double hi;
- /** Point inside the bracket. */
- private double mid;
- /** Function value at {@link #lo}. */
- private double fLo;
- /** Function value at {@link #hi}. */
- private double fHi;
- /** Function value at {@link #mid}. */
- private double fMid;
- /**
- * Constructor with default values {@code 100, 100000} (see the
- * {@link #BracketFinder(double,int) other constructor}).
- */
- BracketFinder() {
- this(100, 100000);
- }
- /**
- * Create a bracketing interval finder.
- *
- * @param growLimit Expanding factor.
- * @param maxEvaluations Maximum number of evaluations allowed for finding
- * a bracketing interval.
- * @throws IllegalArgumentException if the {@code growLimit} or {@code maxEvalutations}
- * are not strictly positive.
- */
- BracketFinder(double growLimit, int maxEvaluations) {
- Arguments.checkStrictlyPositive(growLimit);
- Arguments.checkStrictlyPositive(maxEvaluations);
- this.growLimit = growLimit;
- this.maxEvaluations = maxEvaluations;
- }
- /**
- * Search downhill from the initial points to obtain new points that bracket a local
- * minimum of the function. Note that the initial points do not have to bracket a minimum.
- * An exception is raised if a minimum cannot be found within the configured number
- * of function evaluations.
- *
- * <p>The bracket is limited to the provided bounds if they create a positive interval
- * {@code min < max}. It is possible that the middle of the bracket is at the bounds as
- * the final bracket is {@code f(mid) <= min(f(lo), f(hi))} and {@code lo <= mid <= hi}.
- *
- * <p>No exception is raised if the initial points are not within the bounds; the points
- * are updated to be within the bounds.
- *
- * <p>No exception is raised if the initial points are equal; the bracket will be returned
- * as a single point {@code lo == mid == hi}.
- *
- * @param func Function whose optimum should be bracketed.
- * @param a Initial point.
- * @param b Initial point.
- * @param min Minimum bound of the bracket (inclusive).
- * @param max Maximum bound of the bracket (inclusive).
- * @return true if the mid-point is strictly within the final bracket {@code [lo, hi]};
- * false if there is no local minima.
- * @throws IllegalStateException if the maximum number of evaluations is exceeded.
- */
- boolean search(DoubleUnaryOperator func,
- double a, double b,
- double min, double max) {
- evaluations = 0;
- // Limit the range of x
- final DoubleUnaryOperator range;
- if (min < max) {
- // Limit: min <= x <= max
- range = x -> {
- if (x > min) {
- return x < max ? x : max;
- }
- return min;
- };
- } else {
- range = DoubleUnaryOperator.identity();
- }
- double xA = range.applyAsDouble(a);
- double xB = range.applyAsDouble(b);
- double fA = value(func, xA);
- double fB = value(func, xB);
- // Ensure fB <= fA
- if (fA < fB) {
- double tmp = xA;
- xA = xB;
- xB = tmp;
- tmp = fA;
- fA = fB;
- fB = tmp;
- }
- double xC = range.applyAsDouble(xB + GOLD * (xB - xA));
- double fC = value(func, xC);
- // Note: When a [min, max] interval is provided and there is no minima then this
- // loop will terminate when B == C and both are at the min/max bound.
- while (fC < fB) {
- final double tmp1 = (xB - xA) * (fB - fC);
- final double tmp2 = (xB - xC) * (fB - fA);
- final double val = tmp2 - tmp1;
- // limit magnitude of val to a small value
- final double denom = 2 * Math.copySign(Math.max(Math.abs(val), EPS_MIN), val);
- double w = range.applyAsDouble(xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom);
- final double wLim = range.applyAsDouble(xB + growLimit * (xC - xB));
- double fW;
- if ((w - xC) * (xB - w) > 0) {
- // xB < w < xC
- fW = value(func, w);
- if (fW < fC) {
- // minimum in [xB, xC]
- xA = xB;
- xB = w;
- fA = fB;
- fB = fW;
- break;
- } else if (fW > fB) {
- // minimum in [xA, w]
- xC = w;
- fC = fW;
- break;
- }
- // continue downhill
- w = range.applyAsDouble(xC + GOLD * (xC - xB));
- fW = value(func, w);
- } else if ((w - wLim) * (xC - w) > 0) {
- // xC < w < limit
- fW = value(func, w);
- if (fW < fC) {
- // continue downhill
- xB = xC;
- xC = w;
- w = range.applyAsDouble(xC + GOLD * (xC - xB));
- fB = fC;
- fC = fW;
- fW = value(func, w);
- }
- } else if ((w - wLim) * (wLim - xC) >= 0) {
- // xC <= limit <= w
- w = wLim;
- fW = value(func, w);
- } else {
- // possibly w == xC; reject w and take a default step
- w = range.applyAsDouble(xC + GOLD * (xC - xB));
- fW = value(func, w);
- }
- xA = xB;
- fA = fB;
- xB = xC;
- fB = fC;
- xC = w;
- fC = fW;
- }
- mid = xB;
- fMid = fB;
- // Store the bracket: lo <= mid <= hi
- if (xC < xA) {
- lo = xC;
- fLo = fC;
- hi = xA;
- fHi = fA;
- } else {
- lo = xA;
- fLo = fA;
- hi = xC;
- fHi = fC;
- }
- return lo < mid && mid < hi;
- }
- /**
- * @return the number of evaluations.
- */
- int getEvaluations() {
- return evaluations;
- }
- /**
- * @return the lower bound of the bracket.
- * @see #getFLo()
- */
- double getLo() {
- return lo;
- }
- /**
- * Get function value at {@link #getLo()}.
- * @return function value at {@link #getLo()}
- */
- double getFLo() {
- return fLo;
- }
- /**
- * @return the higher bound of the bracket.
- * @see #getFHi()
- */
- double getHi() {
- return hi;
- }
- /**
- * Get function value at {@link #getHi()}.
- * @return function value at {@link #getHi()}
- */
- double getFHi() {
- return fHi;
- }
- /**
- * @return a point in the middle of the bracket.
- * @see #getFMid()
- */
- double getMid() {
- return mid;
- }
- /**
- * Get function value at {@link #getMid()}.
- * @return function value at {@link #getMid()}
- */
- double getFMid() {
- return fMid;
- }
- /**
- * Get the value of the function.
- *
- * @param func Function.
- * @param x Point.
- * @return the value
- * @throws IllegalStateException if the maximal number of evaluations is exceeded.
- */
- private double value(DoubleUnaryOperator func, double x) {
- if (evaluations >= maxEvaluations) {
- throw new IllegalStateException("Too many evaluations: " + evaluations);
- }
- evaluations++;
- return func.applyAsDouble(x);
- }
- }