001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.optimization.univariate;
018    
019    import org.apache.commons.math.util.Precision;
020    import org.apache.commons.math.util.FastMath;
021    import org.apache.commons.math.exception.NumberIsTooSmallException;
022    import org.apache.commons.math.exception.NotStrictlyPositiveException;
023    import org.apache.commons.math.optimization.ConvergenceChecker;
024    import org.apache.commons.math.optimization.GoalType;
025    
026    /**
027     * Implements Richard Brent's algorithm (from his book "Algorithms for
028     * Minimization without Derivatives", p. 79) for finding minima of real
029     * univariate functions. This implementation is an adaptation partly
030     * based on the Python code from SciPy (module "optimize.py" v0.5).
031     * If the function is defined on some interval {@code (lo, hi)}, then
032     * this method finds an approximation {@code x} to the point at which
033     * the function attains its minimum.
034     * <br/>
035     * The user is responsible for calling {@link
036     * #setConvergenceChecker(ConvergenceChecker) ConvergenceChecker}
037     * prior to using the optimizer.
038     *
039     * @version $Id: BrentOptimizer.java 1181282 2011-10-10 22:35:54Z erans $
040     * @since 2.0
041     */
042    public class BrentOptimizer extends AbstractUnivariateRealOptimizer {
043        /**
044         * Golden section.
045         */
046        private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
047        /**
048         * Minimum relative tolerance.
049         */
050        private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
051        /**
052         * Relative threshold.
053         */
054        private final double relativeThreshold;
055        /**
056         * Absolute threshold.
057         */
058        private final double absoluteThreshold;
059    
060        /**
061         * The arguments are used implement the original stopping criterion
062         * of Brent's algorithm.
063         * {@code abs} and {@code rel} define a tolerance
064         * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
065         * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
066         * where <em>macheps</em> is the relative machine precision. {@code abs} must
067         * be positive.
068         *
069         * @param rel Relative threshold.
070         * @param abs Absolute threshold.
071         * @throws NotStrictlyPositiveException if {@code abs <= 0}.
072         * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
073         */
074        public BrentOptimizer(double rel,
075                              double abs) {
076            if (rel < MIN_RELATIVE_TOLERANCE) {
077                throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
078            }
079            if (abs <= 0) {
080                throw new NotStrictlyPositiveException(abs);
081            }
082            relativeThreshold = rel;
083            absoluteThreshold = abs;
084        }
085    
086        /** {@inheritDoc} */
087        @Override
088        protected UnivariateRealPointValuePair doOptimize() {
089            final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
090            final double lo = getMin();
091            final double mid = getStartValue();
092            final double hi = getMax();
093    
094            // Optional additional convergence criteria.
095            final ConvergenceChecker<UnivariateRealPointValuePair> checker
096                = getConvergenceChecker();
097    
098            double a;
099            double b;
100            if (lo < hi) {
101                a = lo;
102                b = hi;
103            } else {
104                a = hi;
105                b = lo;
106            }
107    
108            double x = mid;
109            double v = x;
110            double w = x;
111            double d = 0;
112            double e = 0;
113            double fx = computeObjectiveValue(x);
114            if (!isMinim) {
115                fx = -fx;
116            }
117            double fv = fx;
118            double fw = fx;
119    
120            UnivariateRealPointValuePair previous = null;
121            UnivariateRealPointValuePair current
122                = new UnivariateRealPointValuePair(x, isMinim ? fx : -fx);
123    
124            int iter = 0;
125            while (true) {
126                final double m = 0.5 * (a + b);
127                final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
128                final double tol2 = 2 * tol1;
129    
130                // Default stopping criterion.
131                final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
132                if (!stop) {
133                    double p = 0;
134                    double q = 0;
135                    double r = 0;
136                    double u = 0;
137    
138                    if (FastMath.abs(e) > tol1) { // Fit parabola.
139                        r = (x - w) * (fx - fv);
140                        q = (x - v) * (fx - fw);
141                        p = (x - v) * q - (x - w) * r;
142                        q = 2 * (q - r);
143    
144                        if (q > 0) {
145                            p = -p;
146                        } else {
147                            q = -q;
148                        }
149    
150                        r = e;
151                        e = d;
152    
153                        if (p > q * (a - x) &&
154                            p < q * (b - x) &&
155                            FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
156                            // Parabolic interpolation step.
157                            d = p / q;
158                            u = x + d;
159    
160                            // f must not be evaluated too close to a or b.
161                            if (u - a < tol2 || b - u < tol2) {
162                                if (x <= m) {
163                                    d = tol1;
164                                } else {
165                                    d = -tol1;
166                                }
167                            }
168                        } else {
169                            // Golden section step.
170                            if (x < m) {
171                                e = b - x;
172                            } else {
173                                e = a - x;
174                            }
175                            d = GOLDEN_SECTION * e;
176                        }
177                    } else {
178                        // Golden section step.
179                        if (x < m) {
180                            e = b - x;
181                        } else {
182                            e = a - x;
183                        }
184                        d = GOLDEN_SECTION * e;
185                    }
186    
187                    // Update by at least "tol1".
188                    if (FastMath.abs(d) < tol1) {
189                        if (d >= 0) {
190                            u = x + tol1;
191                        } else {
192                            u = x - tol1;
193                        }
194                    } else {
195                        u = x + d;
196                    }
197    
198                    double fu = computeObjectiveValue(u);
199                    if (!isMinim) {
200                        fu = -fu;
201                    }
202    
203                    // Update a, b, v, w and x.
204                    if (fu <= fx) {
205                        if (u < x) {
206                            b = x;
207                        } else {
208                            a = x;
209                        }
210                        v = w;
211                        fv = fw;
212                        w = x;
213                        fw = fx;
214                        x = u;
215                        fx = fu;
216                    } else {
217                        if (u < x) {
218                            a = u;
219                        } else {
220                            b = u;
221                        }
222                        if (fu <= fw ||
223                            Precision.equals(w, x)) {
224                            v = w;
225                            fv = fw;
226                            w = u;
227                            fw = fu;
228                        } else if (fu <= fv ||
229                                   Precision.equals(v, x) ||
230                                   Precision.equals(v, w)) {
231                            v = u;
232                            fv = fu;
233                        }
234                    }
235    
236                    previous = current;
237                    current = new UnivariateRealPointValuePair(x, isMinim ? fx : -fx);
238    
239                    // User-defined convergence checker.
240                    if (checker != null) {
241                        if (checker.converged(iter, previous, current)) {
242                            return current;
243                        }
244                    }
245                } else { // Default termination (Brent's criterion).
246                    return current;
247                }
248                ++iter;
249            }
250        }
251    }