View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.inference;
18  
19  import java.util.function.DoubleUnaryOperator;
20  
21  /**
22   * Provide an interval that brackets a local minimum of a function.
23   * This code is based on a Python implementation (from <em>SciPy</em>,
24   * module {@code optimize.py} v0.5).
25   *
26   * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
27   * and modified to: remove support for bracketing a maximum; support bounds
28   * on the bracket; correct the sign of the denominator when the magnitude is small;
29   * and return true/false if there is a minimum strictly inside the bounds.
30   *
31   * @since 1.1
32   */
33  class BracketFinder {
34      /** Tolerance to avoid division by zero. */
35      private static final double EPS_MIN = 1e-21;
36      /** Golden section. */
37      private static final double GOLD = 1.6180339887498948482;
38      /** Factor for expanding the interval. */
39      private final double growLimit;
40      /**  Number of allowed function evaluations. */
41      private final int maxEvaluations;
42      /** Number of function evaluations performed in the last search. */
43      private int evaluations;
44      /** Lower bound of the bracket. */
45      private double lo;
46      /** Higher bound of the bracket. */
47      private double hi;
48      /** Point inside the bracket. */
49      private double mid;
50      /** Function value at {@link #lo}. */
51      private double fLo;
52      /** Function value at {@link #hi}. */
53      private double fHi;
54      /** Function value at {@link #mid}. */
55      private double fMid;
56  
57      /**
58       * Constructor with default values {@code 100, 100000} (see the
59       * {@link #BracketFinder(double,int) other constructor}).
60       */
61      BracketFinder() {
62          this(100, 100000);
63      }
64  
65      /**
66       * Create a bracketing interval finder.
67       *
68       * @param growLimit Expanding factor.
69       * @param maxEvaluations Maximum number of evaluations allowed for finding
70       * a bracketing interval.
71       * @throws IllegalArgumentException if the {@code growLimit} or {@code maxEvalutations}
72       * are not strictly positive.
73       */
74      BracketFinder(double growLimit, int maxEvaluations) {
75          Arguments.checkStrictlyPositive(growLimit);
76          Arguments.checkStrictlyPositive(maxEvaluations);
77          this.growLimit = growLimit;
78          this.maxEvaluations = maxEvaluations;
79      }
80  
81      /**
82       * Search downhill from the initial points to obtain new points that bracket a local
83       * minimum of the function. Note that the initial points do not have to bracket a minimum.
84       * An exception is raised if a minimum cannot be found within the configured number
85       * of function evaluations.
86       *
87       * <p>The bracket is limited to the provided bounds if they create a positive interval
88       * {@code min < max}. It is possible that the middle of the bracket is at the bounds as
89       * the final bracket is {@code f(mid) <= min(f(lo), f(hi))} and {@code lo <= mid <= hi}.
90       *
91       * <p>No exception is raised if the initial points are not within the bounds; the points
92       * are updated to be within the bounds.
93       *
94       * <p>No exception is raised if the initial points are equal; the bracket will be returned
95       * as a single point {@code lo == mid == hi}.
96       *
97       * @param func Function whose optimum should be bracketed.
98       * @param a Initial point.
99       * @param b Initial point.
100      * @param min Minimum bound of the bracket (inclusive).
101      * @param max Maximum bound of the bracket (inclusive).
102      * @return true if the mid-point is strictly within the final bracket {@code [lo, hi]};
103      * false if there is no local minima.
104      * @throws IllegalStateException if the maximum number of evaluations is exceeded.
105      */
106     boolean search(DoubleUnaryOperator func,
107                    double a, double b,
108                    double min, double max) {
109         evaluations = 0;
110 
111         // Limit the range of x
112         final DoubleUnaryOperator range;
113         if (min < max) {
114             // Limit: min <= x <= max
115             range = x -> {
116                 if (x > min) {
117                     return x < max ? x : max;
118                 }
119                 return min;
120             };
121         } else {
122             range = DoubleUnaryOperator.identity();
123         }
124 
125         double xA = range.applyAsDouble(a);
126         double xB = range.applyAsDouble(b);
127         double fA = value(func, xA);
128         double fB = value(func, xB);
129         // Ensure fB <= fA
130         if (fA < fB) {
131             double tmp = xA;
132             xA = xB;
133             xB = tmp;
134             tmp = fA;
135             fA = fB;
136             fB = tmp;
137         }
138 
139         double xC = range.applyAsDouble(xB + GOLD * (xB - xA));
140         double fC = value(func, xC);
141 
142         // Note: When a [min, max] interval is provided and there is no minima then this
143         // loop will terminate when B == C and both are at the min/max bound.
144         while (fC < fB) {
145             final double tmp1 = (xB - xA) * (fB - fC);
146             final double tmp2 = (xB - xC) * (fB - fA);
147 
148             final double val = tmp2 - tmp1;
149             // limit magnitude of val to a small value
150             final double denom = 2 * Math.copySign(Math.max(Math.abs(val), EPS_MIN), val);
151 
152             double w = range.applyAsDouble(xB - ((xB - xC) * tmp2 - (xB - xA) * tmp1) / denom);
153             final double wLim = range.applyAsDouble(xB + growLimit * (xC - xB));
154 
155             double fW;
156             if ((w - xC) * (xB - w) > 0) {
157                 // xB < w < xC
158                 fW = value(func, w);
159                 if (fW < fC) {
160                     // minimum in [xB, xC]
161                     xA = xB;
162                     xB = w;
163                     fA = fB;
164                     fB = fW;
165                     break;
166                 } else if (fW > fB) {
167                     // minimum in [xA, w]
168                     xC = w;
169                     fC = fW;
170                     break;
171                 }
172                 // continue downhill
173                 w = range.applyAsDouble(xC + GOLD * (xC - xB));
174                 fW = value(func, w);
175             } else if ((w - wLim) * (xC - w) > 0) {
176                 // xC < w < limit
177                 fW = value(func, w);
178                 if (fW < fC) {
179                     // continue downhill
180                     xB = xC;
181                     xC = w;
182                     w = range.applyAsDouble(xC + GOLD * (xC - xB));
183                     fB = fC;
184                     fC = fW;
185                     fW = value(func, w);
186                 }
187             } else if ((w - wLim) * (wLim - xC) >= 0) {
188                 // xC <= limit <= w
189                 w = wLim;
190                 fW = value(func, w);
191             } else {
192                 // possibly w == xC; reject w and take a default step
193                 w = range.applyAsDouble(xC + GOLD * (xC - xB));
194                 fW = value(func, w);
195             }
196 
197             xA = xB;
198             fA = fB;
199             xB = xC;
200             fB = fC;
201             xC = w;
202             fC = fW;
203         }
204 
205         mid = xB;
206         fMid = fB;
207 
208         // Store the bracket: lo <= mid <= hi
209         if (xC < xA) {
210             lo = xC;
211             fLo = fC;
212             hi = xA;
213             fHi = fA;
214         } else {
215             lo = xA;
216             fLo = fA;
217             hi = xC;
218             fHi = fC;
219         }
220 
221         return lo < mid && mid < hi;
222     }
223 
224     /**
225      * @return the number of evaluations.
226      */
227     int getEvaluations() {
228         return evaluations;
229     }
230 
231     /**
232      * @return the lower bound of the bracket.
233      * @see #getFLo()
234      */
235     double getLo() {
236         return lo;
237     }
238 
239     /**
240      * Get function value at {@link #getLo()}.
241      * @return function value at {@link #getLo()}
242      */
243     double getFLo() {
244         return fLo;
245     }
246 
247     /**
248      * @return the higher bound of the bracket.
249      * @see #getFHi()
250      */
251     double getHi() {
252         return hi;
253     }
254 
255     /**
256      * Get function value at {@link #getHi()}.
257      * @return function value at {@link #getHi()}
258      */
259     double getFHi() {
260         return fHi;
261     }
262 
263     /**
264      * @return a point in the middle of the bracket.
265      * @see #getFMid()
266      */
267     double getMid() {
268         return mid;
269     }
270 
271     /**
272      * Get function value at {@link #getMid()}.
273      * @return function value at {@link #getMid()}
274      */
275     double getFMid() {
276         return fMid;
277     }
278 
279     /**
280      * Get the value of the function.
281      *
282      * @param func Function.
283      * @param x Point.
284      * @return the value
285      * @throws IllegalStateException if the maximal number of evaluations is exceeded.
286      */
287     private double value(DoubleUnaryOperator func, double x) {
288         if (evaluations >= maxEvaluations) {
289             throw new IllegalStateException("Too many evaluations: " + evaluations);
290         }
291         evaluations++;
292         return func.applyAsDouble(x);
293     }
294 }