View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.inference;
18  
19  import java.util.function.DoubleUnaryOperator;
20  import org.apache.commons.numbers.core.Precision;
21  
22  /**
23   * For a function defined on some interval {@code (lo, hi)}, this class
24   * finds an approximation {@code x} to the point at which the function
25   * attains its minimum.
26   * It implements Richard Brent's algorithm (from his book "Algorithms for
27   * Minimization without Derivatives", p. 79) for finding minima of real
28   * univariate functions.
29   *
30   * <P>This code is an adaptation, partly based on the Python code from SciPy
31   * (module "optimize.py" v0.5); the original algorithm is also modified:
32   * <ul>
33   *  <li>to use an initial guess provided by the user,</li>
34   *  <li>to ensure that the best point encountered is the one returned.</li>
35   * </ul>
36   *
37   * <p>This class has been extracted from {@code o.a.c.math4.optim.univariate}
38   * and simplified to remove support for the UnivariateOptimizer interface.
39   * This removed the options: to find the maximum; use a custom convergence checker
40   * on the function value; and remove the maximum function evaluation count.
41   * The class now implements a single optimize method within the provided bracket
42   * from the given start position (with value).
43   *
44   * @since 1.1
45   */
46  final class BrentOptimizer {
47      /** Golden section. (3 - sqrt(5)) / 2. */
48      private static final double GOLDEN_SECTION = 0.3819660112501051;
49      /** Minimum relative tolerance. 2 * eps = 2^-51. */
50      private static final double MIN_RELATIVE_TOLERANCE = 0x1.0p-51;
51  
52      /** Relative threshold. */
53      private final double relativeThreshold;
54      /** Absolute threshold. */
55      private final double absoluteThreshold;
56      /** The number of function evaluations from the most recent call to optimize. */
57      private int evaluations;
58  
59      /**
60       * This class holds a point and the value of an objective function at this
61       * point. This is a simple immutable container.
62       *
63       * @since 1.1
64       */
65      static final class PointValuePair {
66          /** Point. */
67          private final double point;
68          /** Value of the objective function at the point. */
69          private final double value;
70  
71          /**
72           * @param point Point.
73           * @param value Value of an objective function at the point.
74           */
75          private PointValuePair(double point, double value) {
76              this.point = point;
77              this.value = value;
78          }
79  
80          /**
81           * Create a point/objective function value pair.
82           *
83           * @param point Point.
84           * @param value Value of an objective function at the point.
85           * @return the pair
86           */
87          static PointValuePair of(double point, double value) {
88              return new PointValuePair(point, value);
89          }
90  
91          /**
92           * Get the point.
93           *
94           * @return the point.
95           */
96          double getPoint() {
97              return point;
98          }
99  
100         /**
101          * Get the value of the objective function.
102          *
103          * @return the stored value of the objective function.
104          */
105         double getValue() {
106             return value;
107         }
108     }
109 
110     /**
111      * The arguments are used to implement the original stopping criterion
112      * of Brent's algorithm.
113      * {@code abs} and {@code rel} define a tolerance
114      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
115      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
116      * where <em>macheps</em> is the relative machine precision. {@code abs} must
117      * be positive.
118      *
119      * @param rel Relative threshold.
120      * @param abs Absolute threshold.
121      * @throws IllegalArgumentException if {@code abs <= 0}; or if {@code rel < 2 * Math.ulp(1.0)}
122      */
123     BrentOptimizer(double rel, double abs) {
124         if (rel >= MIN_RELATIVE_TOLERANCE) {
125             relativeThreshold = rel;
126             absoluteThreshold = Arguments.checkStrictlyPositive(abs);
127         } else {
128             // relative too small, or NaN
129             throw new InferenceException(InferenceException.X_LT_Y, rel, MIN_RELATIVE_TOLERANCE);
130         }
131     }
132 
133     /**
134      * Gets the number of function evaluations from the most recent call to
135      * {@link #optimize(DoubleUnaryOperator, double, double, double, double) optimize}.
136      *
137      * @return the function evaluations
138      */
139     int getEvaluations() {
140         return evaluations;
141     }
142 
143     /**
144      * Search for the minimum inside the provided interval. The bracket must satisfy
145      * the equalities {@code lo < mid < hi} or {@code hi < mid < lo}.
146      *
147      * <p>Note: This function accepts the initial guess and the function value at that point.
148      * This is done for convenience as this internal class is used where the caller already
149      * knows the function value.
150      *
151      * @param func Function to solve.
152      * @param lo Lower bound of the search interval.
153      * @param hi Higher bound of the search interval.
154      * @param mid Start point.
155      * @param fMid Function value at the start point.
156      * @return the value where the function is minimum.
157      * @throws IllegalArgumentException if start point is not within the search interval
158      * @throws IllegalStateException if the maximum number of iterations is exceeded
159      */
160     PointValuePair optimize(DoubleUnaryOperator func,
161                             double lo, double hi,
162                             double mid, double fMid) {
163         double a;
164         double b;
165         if (lo < hi) {
166             a = lo;
167             b = hi;
168         } else {
169             a = hi;
170             b = lo;
171         }
172         if (!(a < mid && mid < b)) {
173             throw new InferenceException("Invalid bounds: (%s, %s) with start %s", a, b, mid);
174         }
175         double x = mid;
176         double v = x;
177         double w = x;
178         double d = 0;
179         double e = 0;
180         double fx = fMid;
181         double fv = fx;
182         double fw = fx;
183 
184         // Best point encountered so far (which is the initial guess).
185         double bestX = x;
186         double bestFx = fx;
187 
188         // No test for iteration count.
189         // Note that the termination criterion is based purely on the size of the current
190         // bracket and the current point x. If the function evaluates NaN then golden
191         // section steps are taken.
192         evaluations = 0;
193         for (;;) {
194             final double m = 0.5 * (a + b);
195             final double tol1 = relativeThreshold * Math.abs(x) + absoluteThreshold;
196             final double tol2 = 2 * tol1;
197 
198             // Default termination (Brent's criterion).
199             if (Math.abs(x - m) <= tol2 - 0.5 * (b - a)) {
200                 return PointValuePair.of(bestX, bestFx);
201             }
202 
203             if (Math.abs(e) > tol1) {
204                 // Fit parabola.
205                 double r = (x - w) * (fx - fv);
206                 double q = (x - v) * (fx - fw);
207                 double p = (x - v) * q - (x - w) * r;
208                 q = 2 * (q - r);
209 
210                 if (q > 0) {
211                     p = -p;
212                 } else {
213                     q = -q;
214                 }
215 
216                 r = e;
217                 e = d;
218 
219                 if (p > q * (a - x) &&
220                     p < q * (b - x) &&
221                     Math.abs(p) < Math.abs(0.5 * q * r)) {
222                     // Parabolic interpolation step.
223                     d = p / q;
224                     final double u = x + d;
225 
226                     // f must not be evaluated too close to a or b.
227                     if (u - a < tol2 || b - u < tol2) {
228                         if (x <= m) {
229                             d = tol1;
230                         } else {
231                             d = -tol1;
232                         }
233                     }
234                 } else {
235                     // Golden section step.
236                     if (x < m) {
237                         e = b - x;
238                     } else {
239                         e = a - x;
240                     }
241                     d = GOLDEN_SECTION * e;
242                 }
243             } else {
244                 // Golden section step.
245                 if (x < m) {
246                     e = b - x;
247                 } else {
248                     e = a - x;
249                 }
250                 d = GOLDEN_SECTION * e;
251             }
252 
253             // Update by at least "tol1".
254             // Here d is never NaN so the evaluation point u is always finite.
255             final double u;
256             if (Math.abs(d) < tol1) {
257                 if (d >= 0) {
258                     u = x + tol1;
259                 } else {
260                     u = x - tol1;
261                 }
262             } else {
263                 u = x + d;
264             }
265 
266             evaluations++;
267             final double fu = func.applyAsDouble(u);
268 
269             // Maintain the best encountered result
270             if (fu < bestFx) {
271                 bestX = u;
272                 bestFx = fu;
273             }
274 
275             // Note:
276             // Here the use of a convergence checker on f(x) previous vs current has been removed.
277             // Typically when the checker requires a very small relative difference
278             // the optimizer will stop before, or soon after, on Brent's criterion when that is
279             // configured with the smallest recommended convergence criteria.
280 
281             // Update a, b, v, w and x.
282             if (fu <= fx) {
283                 if (u < x) {
284                     b = x;
285                 } else {
286                     a = x;
287                 }
288                 v = w;
289                 fv = fw;
290                 w = x;
291                 fw = fx;
292                 x = u;
293                 fx = fu;
294             } else {
295                 if (u < x) {
296                     a = u;
297                 } else {
298                     b = u;
299                 }
300                 if (fu <= fw ||
301                     Precision.equals(w, x)) {
302                     v = w;
303                     fv = fw;
304                     w = u;
305                     fw = fu;
306                 } else if (fu <= fv ||
307                            Precision.equals(v, x) ||
308                            Precision.equals(v, w)) {
309                     v = u;
310                     fv = fu;
311                 }
312             }
313         }
314     }
315 }