View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math.optimization.univariate;
18  
19  import org.apache.commons.math.util.Precision;
20  import org.apache.commons.math.util.FastMath;
21  import org.apache.commons.math.exception.NumberIsTooSmallException;
22  import org.apache.commons.math.exception.NotStrictlyPositiveException;
23  import org.apache.commons.math.optimization.ConvergenceChecker;
24  import org.apache.commons.math.optimization.GoalType;
25  
26  /**
27   * Implements Richard Brent's algorithm (from his book "Algorithms for
28   * Minimization without Derivatives", p. 79) for finding minima of real
29   * univariate functions. This implementation is an adaptation partly
30   * based on the Python code from SciPy (module "optimize.py" v0.5).
31   * If the function is defined on some interval {@code (lo, hi)}, then
32   * this method finds an approximation {@code x} to the point at which
33   * the function attains its minimum.
34   * <br/>
35   * The user is responsible for calling {@link
36   * #setConvergenceChecker(ConvergenceChecker) ConvergenceChecker}
37   * prior to using the optimizer.
38   *
39   * @version $Id: BrentOptimizer.java 1181282 2011-10-10 22:35:54Z erans $
40   * @since 2.0
41   */
42  public class BrentOptimizer extends AbstractUnivariateRealOptimizer {
43      /**
44       * Golden section.
45       */
46      private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
47      /**
48       * Minimum relative tolerance.
49       */
50      private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
51      /**
52       * Relative threshold.
53       */
54      private final double relativeThreshold;
55      /**
56       * Absolute threshold.
57       */
58      private final double absoluteThreshold;
59  
60      /**
61       * The arguments are used implement the original stopping criterion
62       * of Brent's algorithm.
63       * {@code abs} and {@code rel} define a tolerance
64       * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
65       * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
66       * where <em>macheps</em> is the relative machine precision. {@code abs} must
67       * be positive.
68       *
69       * @param rel Relative threshold.
70       * @param abs Absolute threshold.
71       * @throws NotStrictlyPositiveException if {@code abs <= 0}.
72       * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
73       */
74      public BrentOptimizer(double rel,
75                            double abs) {
76          if (rel < MIN_RELATIVE_TOLERANCE) {
77              throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
78          }
79          if (abs <= 0) {
80              throw new NotStrictlyPositiveException(abs);
81          }
82          relativeThreshold = rel;
83          absoluteThreshold = abs;
84      }
85  
86      /** {@inheritDoc} */
87      @Override
88      protected UnivariateRealPointValuePair doOptimize() {
89          final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
90          final double lo = getMin();
91          final double mid = getStartValue();
92          final double hi = getMax();
93  
94          // Optional additional convergence criteria.
95          final ConvergenceChecker<UnivariateRealPointValuePair> checker
96              = getConvergenceChecker();
97  
98          double a;
99          double b;
100         if (lo < hi) {
101             a = lo;
102             b = hi;
103         } else {
104             a = hi;
105             b = lo;
106         }
107 
108         double x = mid;
109         double v = x;
110         double w = x;
111         double d = 0;
112         double e = 0;
113         double fx = computeObjectiveValue(x);
114         if (!isMinim) {
115             fx = -fx;
116         }
117         double fv = fx;
118         double fw = fx;
119 
120         UnivariateRealPointValuePair previous = null;
121         UnivariateRealPointValuePair current
122             = new UnivariateRealPointValuePair(x, isMinim ? fx : -fx);
123 
124         int iter = 0;
125         while (true) {
126             final double m = 0.5 * (a + b);
127             final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
128             final double tol2 = 2 * tol1;
129 
130             // Default stopping criterion.
131             final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
132             if (!stop) {
133                 double p = 0;
134                 double q = 0;
135                 double r = 0;
136                 double u = 0;
137 
138                 if (FastMath.abs(e) > tol1) { // Fit parabola.
139                     r = (x - w) * (fx - fv);
140                     q = (x - v) * (fx - fw);
141                     p = (x - v) * q - (x - w) * r;
142                     q = 2 * (q - r);
143 
144                     if (q > 0) {
145                         p = -p;
146                     } else {
147                         q = -q;
148                     }
149 
150                     r = e;
151                     e = d;
152 
153                     if (p > q * (a - x) &&
154                         p < q * (b - x) &&
155                         FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
156                         // Parabolic interpolation step.
157                         d = p / q;
158                         u = x + d;
159 
160                         // f must not be evaluated too close to a or b.
161                         if (u - a < tol2 || b - u < tol2) {
162                             if (x <= m) {
163                                 d = tol1;
164                             } else {
165                                 d = -tol1;
166                             }
167                         }
168                     } else {
169                         // Golden section step.
170                         if (x < m) {
171                             e = b - x;
172                         } else {
173                             e = a - x;
174                         }
175                         d = GOLDEN_SECTION * e;
176                     }
177                 } else {
178                     // Golden section step.
179                     if (x < m) {
180                         e = b - x;
181                     } else {
182                         e = a - x;
183                     }
184                     d = GOLDEN_SECTION * e;
185                 }
186 
187                 // Update by at least "tol1".
188                 if (FastMath.abs(d) < tol1) {
189                     if (d >= 0) {
190                         u = x + tol1;
191                     } else {
192                         u = x - tol1;
193                     }
194                 } else {
195                     u = x + d;
196                 }
197 
198                 double fu = computeObjectiveValue(u);
199                 if (!isMinim) {
200                     fu = -fu;
201                 }
202 
203                 // Update a, b, v, w and x.
204                 if (fu <= fx) {
205                     if (u < x) {
206                         b = x;
207                     } else {
208                         a = x;
209                     }
210                     v = w;
211                     fv = fw;
212                     w = x;
213                     fw = fx;
214                     x = u;
215                     fx = fu;
216                 } else {
217                     if (u < x) {
218                         a = u;
219                     } else {
220                         b = u;
221                     }
222                     if (fu <= fw ||
223                         Precision.equals(w, x)) {
224                         v = w;
225                         fv = fw;
226                         w = u;
227                         fw = fu;
228                     } else if (fu <= fv ||
229                                Precision.equals(v, x) ||
230                                Precision.equals(v, w)) {
231                         v = u;
232                         fv = fu;
233                     }
234                 }
235 
236                 previous = current;
237                 current = new UnivariateRealPointValuePair(x, isMinim ? fx : -fx);
238 
239                 // User-defined convergence checker.
240                 if (checker != null) {
241                     if (checker.converged(iter, previous, current)) {
242                         return current;
243                     }
244                 }
245             } else { // Default termination (Brent's criterion).
246                 return current;
247             }
248             ++iter;
249         }
250     }
251 }