001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.optim.univariate;
018
019import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
020import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
021import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
022import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
023import org.apache.commons.math4.core.jdkmath.JdkMath;
024import org.apache.commons.numbers.core.Precision;
025
026/**
027 * For a function defined on some interval {@code (lo, hi)}, this class
028 * finds an approximation {@code x} to the point at which the function
029 * attains its minimum.
030 * It implements Richard Brent's algorithm (from his book "Algorithms for
031 * Minimization without Derivatives", p. 79) for finding minima of real
032 * univariate functions.
033 * <br>
034 * This code is an adaptation, partly based on the Python code from SciPy
035 * (module "optimize.py" v0.5); the original algorithm is also modified
036 * <ul>
037 *  <li>to use an initial guess provided by the user,</li>
038 *  <li>to ensure that the best point encountered is the one returned.</li>
039 * </ul>
040 *
041 * @since 2.0
042 */
043public class BrentOptimizer extends UnivariateOptimizer {
044    /**
045     * Golden section.
046     */
047    private static final double GOLDEN_SECTION = 0.5 * (3 - JdkMath.sqrt(5));
048    /**
049     * Minimum relative tolerance.
050     */
051    private static final double MIN_RELATIVE_TOLERANCE = 2 * JdkMath.ulp(1d);
052    /**
053     * Relative threshold.
054     */
055    private final double relativeThreshold;
056    /**
057     * Absolute threshold.
058     */
059    private final double absoluteThreshold;
060
061    /**
062     * The arguments are used implement the original stopping criterion
063     * of Brent's algorithm.
064     * {@code abs} and {@code rel} define a tolerance
065     * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
066     * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
067     * where <em>macheps</em> is the relative machine precision. {@code abs} must
068     * be positive.
069     *
070     * @param rel Relative threshold.
071     * @param abs Absolute threshold.
072     * @param checker Additional, user-defined, convergence checking
073     * procedure.
074     * @throws NotStrictlyPositiveException if {@code abs <= 0}.
075     * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
076     */
077    public BrentOptimizer(double rel,
078                          double abs,
079                          ConvergenceChecker<UnivariatePointValuePair> checker) {
080        super(checker);
081
082        if (rel < MIN_RELATIVE_TOLERANCE) {
083            throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
084        }
085        if (abs <= 0) {
086            throw new NotStrictlyPositiveException(abs);
087        }
088
089        relativeThreshold = rel;
090        absoluteThreshold = abs;
091    }
092
093    /**
094     * The arguments are used for implementing the original stopping criterion
095     * of Brent's algorithm.
096     * {@code abs} and {@code rel} define a tolerance
097     * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
098     * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
099     * where <em>macheps</em> is the relative machine precision. {@code abs} must
100     * be positive.
101     *
102     * @param rel Relative threshold.
103     * @param abs Absolute threshold.
104     * @throws NotStrictlyPositiveException if {@code abs <= 0}.
105     * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
106     */
107    public BrentOptimizer(double rel,
108                          double abs) {
109        this(rel, abs, null);
110    }
111
112    /** {@inheritDoc} */
113    @Override
114    protected UnivariatePointValuePair doOptimize() {
115        final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
116        final double lo = getMin();
117        final double mid = getStartValue();
118        final double hi = getMax();
119
120        // Optional additional convergence criteria.
121        final ConvergenceChecker<UnivariatePointValuePair> checker
122            = getConvergenceChecker();
123
124        double a;
125        double b;
126        if (lo < hi) {
127            a = lo;
128            b = hi;
129        } else {
130            a = hi;
131            b = lo;
132        }
133
134        double x = mid;
135        double v = x;
136        double w = x;
137        double d = 0;
138        double e = 0;
139        double fx = computeObjectiveValue(x);
140        if (!isMinim) {
141            fx = -fx;
142        }
143        double fv = fx;
144        double fw = fx;
145
146        UnivariatePointValuePair previous = null;
147        UnivariatePointValuePair current
148            = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
149        // Best point encountered so far (which is the initial guess).
150        UnivariatePointValuePair best = current;
151
152        while (true) {
153            final double m = 0.5 * (a + b);
154            final double tol1 = relativeThreshold * JdkMath.abs(x) + absoluteThreshold;
155            final double tol2 = 2 * tol1;
156
157            // Default stopping criterion.
158            final boolean stop = JdkMath.abs(x - m) <= tol2 - 0.5 * (b - a);
159            if (!stop) {
160                double p = 0;
161                double q = 0;
162                double r = 0;
163                double u = 0;
164
165                if (JdkMath.abs(e) > tol1) { // Fit parabola.
166                    r = (x - w) * (fx - fv);
167                    q = (x - v) * (fx - fw);
168                    p = (x - v) * q - (x - w) * r;
169                    q = 2 * (q - r);
170
171                    if (q > 0) {
172                        p = -p;
173                    } else {
174                        q = -q;
175                    }
176
177                    r = e;
178                    e = d;
179
180                    if (p > q * (a - x) &&
181                        p < q * (b - x) &&
182                        JdkMath.abs(p) < JdkMath.abs(0.5 * q * r)) {
183                        // Parabolic interpolation step.
184                        d = p / q;
185                        u = x + d;
186
187                        // f must not be evaluated too close to a or b.
188                        if (u - a < tol2 || b - u < tol2) {
189                            if (x <= m) {
190                                d = tol1;
191                            } else {
192                                d = -tol1;
193                            }
194                        }
195                    } else {
196                        // Golden section step.
197                        if (x < m) {
198                            e = b - x;
199                        } else {
200                            e = a - x;
201                        }
202                        d = GOLDEN_SECTION * e;
203                    }
204                } else {
205                    // Golden section step.
206                    if (x < m) {
207                        e = b - x;
208                    } else {
209                        e = a - x;
210                    }
211                    d = GOLDEN_SECTION * e;
212                }
213
214                // Update by at least "tol1".
215                if (JdkMath.abs(d) < tol1) {
216                    if (d >= 0) {
217                        u = x + tol1;
218                    } else {
219                        u = x - tol1;
220                    }
221                } else {
222                    u = x + d;
223                }
224
225                double fu = computeObjectiveValue(u);
226                if (!isMinim) {
227                    fu = -fu;
228                }
229
230                // User-defined convergence checker.
231                previous = current;
232                current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
233                best = best(best,
234                            best(previous,
235                                 current,
236                                 isMinim),
237                            isMinim);
238
239                if (checker != null && checker.converged(getIterations(), previous, current)) {
240                    return best;
241                }
242
243                // Update a, b, v, w and x.
244                if (fu <= fx) {
245                    if (u < x) {
246                        b = x;
247                    } else {
248                        a = x;
249                    }
250                    v = w;
251                    fv = fw;
252                    w = x;
253                    fw = fx;
254                    x = u;
255                    fx = fu;
256                } else {
257                    if (u < x) {
258                        a = u;
259                    } else {
260                        b = u;
261                    }
262                    if (fu <= fw ||
263                        Precision.equals(w, x)) {
264                        v = w;
265                        fv = fw;
266                        w = u;
267                        fw = fu;
268                    } else if (fu <= fv ||
269                               Precision.equals(v, x) ||
270                               Precision.equals(v, w)) {
271                        v = u;
272                        fv = fu;
273                    }
274                }
275            } else { // Default termination (Brent's criterion).
276                return best(best,
277                            best(previous,
278                                 current,
279                                 isMinim),
280                            isMinim);
281            }
282
283            incrementIterationCount();
284        }
285    }
286
287    /**
288     * Selects the best of two points.
289     *
290     * @param a Point and value.
291     * @param b Point and value.
292     * @param isMinim {@code true} if the selected point must be the one with
293     * the lowest value.
294     * @return the best point, or {@code null} if {@code a} and {@code b} are
295     * both {@code null}. When {@code a} and {@code b} have the same function
296     * value, {@code a} is returned.
297     */
298    private UnivariatePointValuePair best(UnivariatePointValuePair a,
299                                          UnivariatePointValuePair b,
300                                          boolean isMinim) {
301        if (a == null) {
302            return b;
303        }
304        if (b == null) {
305            return a;
306        }
307
308        if (isMinim) {
309            return a.getValue() <= b.getValue() ? a : b;
310        } else {
311            return a.getValue() >= b.getValue() ? a : b;
312        }
313    }
314}