View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.optim.univariate;
18  
19  import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
20  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
21  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
22  import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
23  import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
24  import org.apache.commons.math4.core.jdkmath.JdkMath;
25  import org.apache.commons.numbers.core.Precision;
26  
27  /**
28   * For a function defined on some interval {@code (lo, hi)}, this class
29   * finds an approximation {@code x} to the point at which the function
30   * attains its minimum.
31   * It implements Richard Brent's algorithm (from his book "Algorithms for
32   * Minimization without Derivatives", p. 79) for finding minima of real
33   * univariate functions.
34   * <br>
35   * This code is an adaptation, partly based on the Python code from SciPy
36   * (module "optimize.py" v0.5); the original algorithm is also modified
37   * <ul>
38   *  <li>to use an initial guess provided by the user,</li>
39   *  <li>to ensure that the best point encountered is the one returned.</li>
40   * </ul>
41   *
42   * @since 2.0
43   */
44  public class BrentOptimizer extends UnivariateOptimizer {
45      /**
46       * Golden section.
47       */
48      private static final double GOLDEN_SECTION = 0.5 * (3 - JdkMath.sqrt(5));
49      /**
50       * Minimum relative tolerance.
51       */
52      private static final double MIN_RELATIVE_TOLERANCE = 2 * JdkMath.ulp(1d);
53      /**
54       * Relative threshold.
55       */
56      private final double relativeThreshold;
57      /**
58       * Absolute threshold.
59       */
60      private final double absoluteThreshold;
61  
62      /**
63       * The arguments are used implement the original stopping criterion
64       * of Brent's algorithm.
65       * {@code abs} and {@code rel} define a tolerance
66       * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
67       * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
68       * where <em>macheps</em> is the relative machine precision. {@code abs} must
69       * be positive.
70       *
71       * @param rel Relative threshold.
72       * @param abs Absolute threshold.
73       * @param checker Additional, user-defined, convergence checking
74       * procedure.
75       * @throws NotStrictlyPositiveException if {@code abs <= 0}.
76       * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
77       */
78      public BrentOptimizer(double rel,
79                            double abs,
80                            ConvergenceChecker<UnivariatePointValuePair> checker) {
81          super(checker);
82  
83          if (rel < MIN_RELATIVE_TOLERANCE) {
84              throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
85          }
86          if (abs <= 0) {
87              throw new NotStrictlyPositiveException(abs);
88          }
89  
90          relativeThreshold = rel;
91          absoluteThreshold = abs;
92      }
93  
94      /**
95       * The arguments are used for implementing the original stopping criterion
96       * of Brent's algorithm.
97       * {@code abs} and {@code rel} define a tolerance
98       * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
99       * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
100      * where <em>macheps</em> is the relative machine precision. {@code abs} must
101      * be positive.
102      *
103      * @param rel Relative threshold.
104      * @param abs Absolute threshold.
105      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
106      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
107      */
108     public BrentOptimizer(double rel,
109                           double abs) {
110         this(rel, abs, null);
111     }
112 
113     /** {@inheritDoc} */
114     @Override
115     protected UnivariatePointValuePair doOptimize() {
116         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
117         final double lo = getMin();
118         final double mid = getStartValue();
119         final double hi = getMax();
120         final UnivariateFunction func = getObjectiveFunction();
121 
122         // Optional additional convergence criteria.
123         final ConvergenceChecker<UnivariatePointValuePair> checker
124             = getConvergenceChecker();
125 
126         double a;
127         double b;
128         if (lo < hi) {
129             a = lo;
130             b = hi;
131         } else {
132             a = hi;
133             b = lo;
134         }
135 
136         double x = mid;
137         double v = x;
138         double w = x;
139         double d = 0;
140         double e = 0;
141         double fx = func.value(x);
142         if (!isMinim) {
143             fx = -fx;
144         }
145         double fv = fx;
146         double fw = fx;
147 
148         UnivariatePointValuePair previous = null;
149         UnivariatePointValuePair current
150             = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
151         // Best point encountered so far (which is the initial guess).
152         UnivariatePointValuePair best = current;
153 
154         while (true) {
155             final double m = 0.5 * (a + b);
156             final double tol1 = relativeThreshold * JdkMath.abs(x) + absoluteThreshold;
157             final double tol2 = 2 * tol1;
158 
159             // Default stopping criterion.
160             final boolean stop = JdkMath.abs(x - m) <= tol2 - 0.5 * (b - a);
161             if (!stop) {
162                 double p = 0;
163                 double q = 0;
164                 double r = 0;
165                 double u = 0;
166 
167                 if (JdkMath.abs(e) > tol1) { // Fit parabola.
168                     r = (x - w) * (fx - fv);
169                     q = (x - v) * (fx - fw);
170                     p = (x - v) * q - (x - w) * r;
171                     q = 2 * (q - r);
172 
173                     if (q > 0) {
174                         p = -p;
175                     } else {
176                         q = -q;
177                     }
178 
179                     r = e;
180                     e = d;
181 
182                     if (p > q * (a - x) &&
183                         p < q * (b - x) &&
184                         JdkMath.abs(p) < JdkMath.abs(0.5 * q * r)) {
185                         // Parabolic interpolation step.
186                         d = p / q;
187                         u = x + d;
188 
189                         // f must not be evaluated too close to a or b.
190                         if (u - a < tol2 || b - u < tol2) {
191                             if (x <= m) {
192                                 d = tol1;
193                             } else {
194                                 d = -tol1;
195                             }
196                         }
197                     } else {
198                         // Golden section step.
199                         if (x < m) {
200                             e = b - x;
201                         } else {
202                             e = a - x;
203                         }
204                         d = GOLDEN_SECTION * e;
205                     }
206                 } else {
207                     // Golden section step.
208                     if (x < m) {
209                         e = b - x;
210                     } else {
211                         e = a - x;
212                     }
213                     d = GOLDEN_SECTION * e;
214                 }
215 
216                 // Update by at least "tol1".
217                 if (JdkMath.abs(d) < tol1) {
218                     if (d >= 0) {
219                         u = x + tol1;
220                     } else {
221                         u = x - tol1;
222                     }
223                 } else {
224                     u = x + d;
225                 }
226 
227                 double fu = func.value(u);
228                 if (!isMinim) {
229                     fu = -fu;
230                 }
231 
232                 // User-defined convergence checker.
233                 previous = current;
234                 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
235                 best = best(best,
236                             best(previous,
237                                  current,
238                                  isMinim),
239                             isMinim);
240 
241                 if (checker != null && checker.converged(getIterations(), previous, current)) {
242                     return best;
243                 }
244 
245                 // Update a, b, v, w and x.
246                 if (fu <= fx) {
247                     if (u < x) {
248                         b = x;
249                     } else {
250                         a = x;
251                     }
252                     v = w;
253                     fv = fw;
254                     w = x;
255                     fw = fx;
256                     x = u;
257                     fx = fu;
258                 } else {
259                     if (u < x) {
260                         a = u;
261                     } else {
262                         b = u;
263                     }
264                     if (fu <= fw ||
265                         Precision.equals(w, x)) {
266                         v = w;
267                         fv = fw;
268                         w = u;
269                         fw = fu;
270                     } else if (fu <= fv ||
271                                Precision.equals(v, x) ||
272                                Precision.equals(v, w)) {
273                         v = u;
274                         fv = fu;
275                     }
276                 }
277             } else { // Default termination (Brent's criterion).
278                 return best(best,
279                             best(previous,
280                                  current,
281                                  isMinim),
282                             isMinim);
283             }
284 
285             incrementIterationCount();
286         }
287     }
288 
289     /**
290      * Selects the best of two points.
291      *
292      * @param a Point and value.
293      * @param b Point and value.
294      * @param isMinim {@code true} if the selected point must be the one with
295      * the lowest value.
296      * @return the best point, or {@code null} if {@code a} and {@code b} are
297      * both {@code null}. When {@code a} and {@code b} have the same function
298      * value, {@code a} is returned.
299      */
300     private UnivariatePointValuePair best(UnivariatePointValuePair a,
301                                           UnivariatePointValuePair b,
302                                           boolean isMinim) {
303         if (a == null) {
304             return b;
305         }
306         if (b == null) {
307             return a;
308         }
309 
310         if (isMinim) {
311             return a.getValue() <= b.getValue() ? a : b;
312         } else {
313             return a.getValue() >= b.getValue() ? a : b;
314         }
315     }
316 }