View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math3.optim.univariate;
18  
19  import org.apache.commons.math3.util.Precision;
20  import org.apache.commons.math3.util.FastMath;
21  import org.apache.commons.math3.exception.NumberIsTooSmallException;
22  import org.apache.commons.math3.exception.NotStrictlyPositiveException;
23  import org.apache.commons.math3.optim.ConvergenceChecker;
24  import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
25  
26  /**
27   * For a function defined on some interval {@code (lo, hi)}, this class
28   * finds an approximation {@code x} to the point at which the function
29   * attains its minimum.
30   * It implements Richard Brent's algorithm (from his book "Algorithms for
31   * Minimization without Derivatives", p. 79) for finding minima of real
32   * univariate functions.
33   * <br/>
34   * This code is an adaptation, partly based on the Python code from SciPy
35   * (module "optimize.py" v0.5); the original algorithm is also modified
36   * <ul>
37   *  <li>to use an initial guess provided by the user,</li>
38   *  <li>to ensure that the best point encountered is the one returned.</li>
39   * </ul>
40   *
41   * @since 2.0
42   */
43  public class BrentOptimizer extends UnivariateOptimizer {
44      /**
45       * Golden section.
46       */
47      private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
48      /**
49       * Minimum relative tolerance.
50       */
51      private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
52      /**
53       * Relative threshold.
54       */
55      private final double relativeThreshold;
56      /**
57       * Absolute threshold.
58       */
59      private final double absoluteThreshold;
60  
61      /**
62       * The arguments are used implement the original stopping criterion
63       * of Brent's algorithm.
64       * {@code abs} and {@code rel} define a tolerance
65       * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
66       * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
67       * where <em>macheps</em> is the relative machine precision. {@code abs} must
68       * be positive.
69       *
70       * @param rel Relative threshold.
71       * @param abs Absolute threshold.
72       * @param checker Additional, user-defined, convergence checking
73       * procedure.
74       * @throws NotStrictlyPositiveException if {@code abs <= 0}.
75       * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
76       */
77      public BrentOptimizer(double rel,
78                            double abs,
79                            ConvergenceChecker<UnivariatePointValuePair> checker) {
80          super(checker);
81  
82          if (rel < MIN_RELATIVE_TOLERANCE) {
83              throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
84          }
85          if (abs <= 0) {
86              throw new NotStrictlyPositiveException(abs);
87          }
88  
89          relativeThreshold = rel;
90          absoluteThreshold = abs;
91      }
92  
93      /**
94       * The arguments are used for implementing the original stopping criterion
95       * of Brent's algorithm.
96       * {@code abs} and {@code rel} define a tolerance
97       * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
98       * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
99       * where <em>macheps</em> is the relative machine precision. {@code abs} must
100      * be positive.
101      *
102      * @param rel Relative threshold.
103      * @param abs Absolute threshold.
104      * @throws NotStrictlyPositiveException if {@code abs <= 0}.
105      * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
106      */
107     public BrentOptimizer(double rel,
108                           double abs) {
109         this(rel, abs, null);
110     }
111 
112     /** {@inheritDoc} */
113     @Override
114     protected UnivariatePointValuePair doOptimize() {
115         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
116         final double lo = getMin();
117         final double mid = getStartValue();
118         final double hi = getMax();
119 
120         // Optional additional convergence criteria.
121         final ConvergenceChecker<UnivariatePointValuePair> checker
122             = getConvergenceChecker();
123 
124         double a;
125         double b;
126         if (lo < hi) {
127             a = lo;
128             b = hi;
129         } else {
130             a = hi;
131             b = lo;
132         }
133 
134         double x = mid;
135         double v = x;
136         double w = x;
137         double d = 0;
138         double e = 0;
139         double fx = computeObjectiveValue(x);
140         if (!isMinim) {
141             fx = -fx;
142         }
143         double fv = fx;
144         double fw = fx;
145 
146         UnivariatePointValuePair previous = null;
147         UnivariatePointValuePair current
148             = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
149         // Best point encountered so far (which is the initial guess).
150         UnivariatePointValuePair best = current;
151 
152         while (true) {
153             final double m = 0.5 * (a + b);
154             final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
155             final double tol2 = 2 * tol1;
156 
157             // Default stopping criterion.
158             final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
159             if (!stop) {
160                 double p = 0;
161                 double q = 0;
162                 double r = 0;
163                 double u = 0;
164 
165                 if (FastMath.abs(e) > tol1) { // Fit parabola.
166                     r = (x - w) * (fx - fv);
167                     q = (x - v) * (fx - fw);
168                     p = (x - v) * q - (x - w) * r;
169                     q = 2 * (q - r);
170 
171                     if (q > 0) {
172                         p = -p;
173                     } else {
174                         q = -q;
175                     }
176 
177                     r = e;
178                     e = d;
179 
180                     if (p > q * (a - x) &&
181                         p < q * (b - x) &&
182                         FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
183                         // Parabolic interpolation step.
184                         d = p / q;
185                         u = x + d;
186 
187                         // f must not be evaluated too close to a or b.
188                         if (u - a < tol2 || b - u < tol2) {
189                             if (x <= m) {
190                                 d = tol1;
191                             } else {
192                                 d = -tol1;
193                             }
194                         }
195                     } else {
196                         // Golden section step.
197                         if (x < m) {
198                             e = b - x;
199                         } else {
200                             e = a - x;
201                         }
202                         d = GOLDEN_SECTION * e;
203                     }
204                 } else {
205                     // Golden section step.
206                     if (x < m) {
207                         e = b - x;
208                     } else {
209                         e = a - x;
210                     }
211                     d = GOLDEN_SECTION * e;
212                 }
213 
214                 // Update by at least "tol1".
215                 if (FastMath.abs(d) < tol1) {
216                     if (d >= 0) {
217                         u = x + tol1;
218                     } else {
219                         u = x - tol1;
220                     }
221                 } else {
222                     u = x + d;
223                 }
224 
225                 double fu = computeObjectiveValue(u);
226                 if (!isMinim) {
227                     fu = -fu;
228                 }
229 
230                 // User-defined convergence checker.
231                 previous = current;
232                 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
233                 best = best(best,
234                             best(previous,
235                                  current,
236                                  isMinim),
237                             isMinim);
238 
239                 if (checker != null && checker.converged(getIterations(), previous, current)) {
240                     return best;
241                 }
242 
243                 // Update a, b, v, w and x.
244                 if (fu <= fx) {
245                     if (u < x) {
246                         b = x;
247                     } else {
248                         a = x;
249                     }
250                     v = w;
251                     fv = fw;
252                     w = x;
253                     fw = fx;
254                     x = u;
255                     fx = fu;
256                 } else {
257                     if (u < x) {
258                         a = u;
259                     } else {
260                         b = u;
261                     }
262                     if (fu <= fw ||
263                         Precision.equals(w, x)) {
264                         v = w;
265                         fv = fw;
266                         w = u;
267                         fw = fu;
268                     } else if (fu <= fv ||
269                                Precision.equals(v, x) ||
270                                Precision.equals(v, w)) {
271                         v = u;
272                         fv = fu;
273                     }
274                 }
275             } else { // Default termination (Brent's criterion).
276                 return best(best,
277                             best(previous,
278                                  current,
279                                  isMinim),
280                             isMinim);
281             }
282 
283             incrementIterationCount();
284         }
285     }
286 
287     /**
288      * Selects the best of two points.
289      *
290      * @param a Point and value.
291      * @param b Point and value.
292      * @param isMinim {@code true} if the selected point must be the one with
293      * the lowest value.
294      * @return the best point, or {@code null} if {@code a} and {@code b} are
295      * both {@code null}. When {@code a} and {@code b} have the same function
296      * value, {@code a} is returned.
297      */
298     private UnivariatePointValuePair best(UnivariatePointValuePair a,
299                                           UnivariatePointValuePair b,
300                                           boolean isMinim) {
301         if (a == null) {
302             return b;
303         }
304         if (b == null) {
305             return a;
306         }
307 
308         if (isMinim) {
309             return a.getValue() <= b.getValue() ? a : b;
310         } else {
311             return a.getValue() >= b.getValue() ? a : b;
312         }
313     }
314 }