/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017package org.apache.commons.math4.legacy.optim.univariate;
018
019import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
020import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
021import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
022import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
023import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
024import org.apache.commons.math4.core.jdkmath.JdkMath;
025import org.apache.commons.numbers.core.Precision;
026
027/**
028 * For a function defined on some interval {@code (lo, hi)}, this class
029 * finds an approximation {@code x} to the point at which the function
030 * attains its minimum.
031 * It implements Richard Brent's algorithm (from his book "Algorithms for
032 * Minimization without Derivatives", p. 79) for finding minima of real
033 * univariate functions.
034 * <br>
035 * This code is an adaptation, partly based on the Python code from SciPy
036 * (module "optimize.py" v0.5); the original algorithm is also modified
037 * <ul>
038 *  <li>to use an initial guess provided by the user,</li>
039 *  <li>to ensure that the best point encountered is the one returned.</li>
040 * </ul>
041 *
042 * @since 2.0
043 */
044public class BrentOptimizer extends UnivariateOptimizer {
045    /**
046     * Golden section.
047     */
048    private static final double GOLDEN_SECTION = 0.5 * (3 - JdkMath.sqrt(5));
049    /**
050     * Minimum relative tolerance.
051     */
052    private static final double MIN_RELATIVE_TOLERANCE = 2 * JdkMath.ulp(1d);
053    /**
054     * Relative threshold.
055     */
056    private final double relativeThreshold;
057    /**
058     * Absolute threshold.
059     */
060    private final double absoluteThreshold;
061
062    /**
063     * The arguments are used implement the original stopping criterion
064     * of Brent's algorithm.
065     * {@code abs} and {@code rel} define a tolerance
066     * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
067     * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
068     * where <em>macheps</em> is the relative machine precision. {@code abs} must
069     * be positive.
070     *
071     * @param rel Relative threshold.
072     * @param abs Absolute threshold.
073     * @param checker Additional, user-defined, convergence checking
074     * procedure.
075     * @throws NotStrictlyPositiveException if {@code abs <= 0}.
076     * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
077     */
078    public BrentOptimizer(double rel,
079                          double abs,
080                          ConvergenceChecker<UnivariatePointValuePair> checker) {
081        super(checker);
082
083        if (rel < MIN_RELATIVE_TOLERANCE) {
084            throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
085        }
086        if (abs <= 0) {
087            throw new NotStrictlyPositiveException(abs);
088        }
089
090        relativeThreshold = rel;
091        absoluteThreshold = abs;
092    }
093
094    /**
095     * The arguments are used for implementing the original stopping criterion
096     * of Brent's algorithm.
097     * {@code abs} and {@code rel} define a tolerance
098     * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
099     * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
100     * where <em>macheps</em> is the relative machine precision. {@code abs} must
101     * be positive.
102     *
103     * @param rel Relative threshold.
104     * @param abs Absolute threshold.
105     * @throws NotStrictlyPositiveException if {@code abs <= 0}.
106     * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
107     */
108    public BrentOptimizer(double rel,
109                          double abs) {
110        this(rel, abs, null);
111    }
112
113    /** {@inheritDoc} */
114    @Override
115    protected UnivariatePointValuePair doOptimize() {
116        final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
117        final double lo = getMin();
118        final double mid = getStartValue();
119        final double hi = getMax();
120        final UnivariateFunction func = getObjectiveFunction();
121
122        // Optional additional convergence criteria.
123        final ConvergenceChecker<UnivariatePointValuePair> checker
124            = getConvergenceChecker();
125
126        double a;
127        double b;
128        if (lo < hi) {
129            a = lo;
130            b = hi;
131        } else {
132            a = hi;
133            b = lo;
134        }
135
136        double x = mid;
137        double v = x;
138        double w = x;
139        double d = 0;
140        double e = 0;
141        double fx = func.value(x);
142        if (!isMinim) {
143            fx = -fx;
144        }
145        double fv = fx;
146        double fw = fx;
147
148        UnivariatePointValuePair previous = null;
149        UnivariatePointValuePair current
150            = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
151        // Best point encountered so far (which is the initial guess).
152        UnivariatePointValuePair best = current;
153
154        while (true) {
155            final double m = 0.5 * (a + b);
156            final double tol1 = relativeThreshold * JdkMath.abs(x) + absoluteThreshold;
157            final double tol2 = 2 * tol1;
158
159            // Default stopping criterion.
160            final boolean stop = JdkMath.abs(x - m) <= tol2 - 0.5 * (b - a);
161            if (!stop) {
162                double p = 0;
163                double q = 0;
164                double r = 0;
165                double u = 0;
166
167                if (JdkMath.abs(e) > tol1) { // Fit parabola.
168                    r = (x - w) * (fx - fv);
169                    q = (x - v) * (fx - fw);
170                    p = (x - v) * q - (x - w) * r;
171                    q = 2 * (q - r);
172
173                    if (q > 0) {
174                        p = -p;
175                    } else {
176                        q = -q;
177                    }
178
179                    r = e;
180                    e = d;
181
182                    if (p > q * (a - x) &&
183                        p < q * (b - x) &&
184                        JdkMath.abs(p) < JdkMath.abs(0.5 * q * r)) {
185                        // Parabolic interpolation step.
186                        d = p / q;
187                        u = x + d;
188
189                        // f must not be evaluated too close to a or b.
190                        if (u - a < tol2 || b - u < tol2) {
191                            if (x <= m) {
192                                d = tol1;
193                            } else {
194                                d = -tol1;
195                            }
196                        }
197                    } else {
198                        // Golden section step.
199                        if (x < m) {
200                            e = b - x;
201                        } else {
202                            e = a - x;
203                        }
204                        d = GOLDEN_SECTION * e;
205                    }
206                } else {
207                    // Golden section step.
208                    if (x < m) {
209                        e = b - x;
210                    } else {
211                        e = a - x;
212                    }
213                    d = GOLDEN_SECTION * e;
214                }
215
216                // Update by at least "tol1".
217                if (JdkMath.abs(d) < tol1) {
218                    if (d >= 0) {
219                        u = x + tol1;
220                    } else {
221                        u = x - tol1;
222                    }
223                } else {
224                    u = x + d;
225                }
226
227                double fu = func.value(u);
228                if (!isMinim) {
229                    fu = -fu;
230                }
231
232                // User-defined convergence checker.
233                previous = current;
234                current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
235                best = best(best,
236                            best(previous,
237                                 current,
238                                 isMinim),
239                            isMinim);
240
241                if (checker != null && checker.converged(getIterations(), previous, current)) {
242                    return best;
243                }
244
245                // Update a, b, v, w and x.
246                if (fu <= fx) {
247                    if (u < x) {
248                        b = x;
249                    } else {
250                        a = x;
251                    }
252                    v = w;
253                    fv = fw;
254                    w = x;
255                    fw = fx;
256                    x = u;
257                    fx = fu;
258                } else {
259                    if (u < x) {
260                        a = u;
261                    } else {
262                        b = u;
263                    }
264                    if (fu <= fw ||
265                        Precision.equals(w, x)) {
266                        v = w;
267                        fv = fw;
268                        w = u;
269                        fw = fu;
270                    } else if (fu <= fv ||
271                               Precision.equals(v, x) ||
272                               Precision.equals(v, w)) {
273                        v = u;
274                        fv = fu;
275                    }
276                }
277            } else { // Default termination (Brent's criterion).
278                return best(best,
279                            best(previous,
280                                 current,
281                                 isMinim),
282                            isMinim);
283            }
284
285            incrementIterationCount();
286        }
287    }
288
289    /**
290     * Selects the best of two points.
291     *
292     * @param a Point and value.
293     * @param b Point and value.
294     * @param isMinim {@code true} if the selected point must be the one with
295     * the lowest value.
296     * @return the best point, or {@code null} if {@code a} and {@code b} are
297     * both {@code null}. When {@code a} and {@code b} have the same function
298     * value, {@code a} is returned.
299     */
300    private UnivariatePointValuePair best(UnivariatePointValuePair a,
301                                          UnivariatePointValuePair b,
302                                          boolean isMinim) {
303        if (a == null) {
304            return b;
305        }
306        if (b == null) {
307            return a;
308        }
309
310        if (isMinim) {
311            return a.getValue() <= b.getValue() ? a : b;
312        } else {
313            return a.getValue() >= b.getValue() ? a : b;
314        }
315    }
316}