001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.analysis;
018    
019    
020    import org.apache.commons.math.FunctionEvaluationException;
021    import org.apache.commons.math.MaxIterationsExceededException;
022    
023    /**
024     * Implements the <a href="http://mathworld.wolfram.com/BrentsMethod.html">
025     * Brent algorithm</a> for  finding zeros of real univariate functions.
026     * <p>
027     * The function should be continuous but not necessarily smooth.</p>
028     *  
029     * @version $Revision:670469 $ $Date:2008-06-23 10:01:38 +0200 (lun., 23 juin 2008) $
030     */
031    public class BrentSolver extends UnivariateRealSolverImpl {
032        
033        /** Serializable version identifier */
034        private static final long serialVersionUID = -2136672307739067002L;
035    
036        /**
037         * Construct a solver for the given function.
038         * 
039         * @param f function to solve.
040         */
041        public BrentSolver(UnivariateRealFunction f) {
042            super(f, 100, 1E-6);
043        }
044    
045        /**
046         * Find a zero in the given interval with an initial guess.
047         * <p>Throws <code>IllegalArgumentException</code> if the values of the
048         * function at the three points have the same sign (note that it is
049         * allowed to have endpoints with the same sign if the initial point has
050         * opposite sign function-wise).</p>
051         * 
052         * @param min the lower bound for the interval.
053         * @param max the upper bound for the interval.
054         * @param initial the start value to use (must be set to min if no
055         * initial point is known).
056         * @return the value where the function is zero
057         * @throws MaxIterationsExceededException the maximum iteration count
058         * is exceeded 
059         * @throws FunctionEvaluationException if an error occurs evaluating
060         *  the function
061         * @throws IllegalArgumentException if initial is not between min and max
062         * (even if it <em>is</em> a root)
063         */
064        public double solve(double min, double max, double initial)
065            throws MaxIterationsExceededException, FunctionEvaluationException {
066    
067            if (((initial - min) * (max -initial)) < 0) {
068                throw new IllegalArgumentException("Initial guess is not in search" +
069                          " interval." + "  Initial: " + initial +
070                          "  Endpoints: [" + min + "," + max + "]");
071            }
072    
073            // return the initial guess if it is good enough
074            double yInitial = f.value(initial);
075            if (Math.abs(yInitial) <= functionValueAccuracy) {
076                setResult(initial, 0);
077                return result;
078            }
079    
080            // return the first endpoint if it is good enough
081            double yMin = f.value(min);
082            if (Math.abs(yMin) <= functionValueAccuracy) {
083                setResult(yMin, 0);
084                return result;
085            }
086    
087            // reduce interval if min and initial bracket the root
088            if (yInitial * yMin < 0) {
089                return solve(min, yMin, initial, yInitial, min, yMin);
090            }
091    
092            // return the second endpoint if it is good enough
093            double yMax = f.value(max);
094            if (Math.abs(yMax) <= functionValueAccuracy) {
095                setResult(yMax, 0);
096                return result;
097            }
098    
099            // reduce interval if initial and max bracket the root
100            if (yInitial * yMax < 0) {
101                return solve(initial, yInitial, max, yMax, initial, yInitial);
102            }
103    
104            // full Brent algorithm starting with provided initial guess
105            return solve(min, yMin, max, yMax, initial, yInitial);
106    
107        }
108        
109        /**
110         * Find a zero in the given interval.
111         * <p>
112         * Requires that the values of the function at the endpoints have opposite
113         * signs. An <code>IllegalArgumentException</code> is thrown if this is not
114         * the case.</p>
115         * 
116         * @param min the lower bound for the interval.
117         * @param max the upper bound for the interval.
118         * @return the value where the function is zero
119         * @throws MaxIterationsExceededException if the maximum iteration count is exceeded
120         * @throws FunctionEvaluationException if an error occurs evaluating the
121         * function 
122         * @throws IllegalArgumentException if min is not less than max or the
123         * signs of the values of the function at the endpoints are not opposites
124         */
125        public double solve(double min, double max) throws MaxIterationsExceededException, 
126            FunctionEvaluationException {
127            
128            clearResult();
129            verifyInterval(min, max);
130            
131            double ret = Double.NaN;
132            
133            double yMin = f.value(min);
134            double yMax = f.value(max);
135            
136            // Verify bracketing
137            double sign = yMin * yMax;
138            if (sign > 0) {
139                // check if either value is close to a zero
140                if (Math.abs(yMin) <= functionValueAccuracy) {
141                    setResult(min, 0);
142                    ret = min;
143                } else if (Math.abs(yMax) <= functionValueAccuracy) {
144                    setResult(max, 0);
145                    ret = max;
146                } else {
147                    // neither value is close to zero and min and max do not bracket root.
148                    throw new IllegalArgumentException
149                    ("Function values at endpoints do not have different signs." +
150                            "  Endpoints: [" + min + "," + max + "]" + 
151                            "  Values: [" + yMin + "," + yMax + "]");
152                }
153            } else if (sign < 0){
154                // solve using only the first endpoint as initial guess
155                ret = solve(min, yMin, max, yMax, min, yMin);
156            } else {
157                // either min or max is a root
158                if (yMin == 0.0) {
159                    ret = min;
160                } else {
161                    ret = max;
162                }
163            }
164    
165            return ret;
166        }
167            
168        /**
169         * Find a zero starting search according to the three provided points.
170         * @param x0 old approximation for the root
171         * @param y0 function value at the approximation for the root
172         * @param x1 last calculated approximation for the root
173         * @param y1 function value at the last calculated approximation
174         * for the root
175         * @param x2 bracket point (must be set to x0 if no bracket point is
176         * known, this will force starting with linear interpolation)
177         * @param y2 function value at the bracket point.
178         * @return the value where the function is zero
179         * @throws MaxIterationsExceededException if the maximum iteration count
180         * is exceeded
181         * @throws FunctionEvaluationException if an error occurs evaluating
182         * the function 
183         */
184        private double solve(double x0, double y0,
185                             double x1, double y1,
186                             double x2, double y2)
187        throws MaxIterationsExceededException, FunctionEvaluationException {
188    
189            double delta = x1 - x0;
190            double oldDelta = delta;
191    
192            int i = 0;
193            while (i < maximalIterationCount) {
194                if (Math.abs(y2) < Math.abs(y1)) {
195                    // use the bracket point if is better than last approximation
196                    x0 = x1;
197                    x1 = x2;
198                    x2 = x0;
199                    y0 = y1;
200                    y1 = y2;
201                    y2 = y0;
202                }
203                if (Math.abs(y1) <= functionValueAccuracy) {
204                    // Avoid division by very small values. Assume
205                    // the iteration has converged (the problem may
206                    // still be ill conditioned)
207                    setResult(x1, i);
208                    return result;
209                }
210                double dx = (x2 - x1);
211                double tolerance =
212                    Math.max(relativeAccuracy * Math.abs(x1), absoluteAccuracy);
213                if (Math.abs(dx) <= tolerance) {
214                    setResult(x1, i);
215                    return result;
216                }
217                if ((Math.abs(oldDelta) < tolerance) ||
218                        (Math.abs(y0) <= Math.abs(y1))) {
219                    // Force bisection.
220                    delta = 0.5 * dx;
221                    oldDelta = delta;
222                } else {
223                    double r3 = y1 / y0;
224                    double p;
225                    double p1;
226                    // the equality test (x0 == x2) is intentional,
227                    // it is part of the original Brent's method,
228                    // it should NOT be replaced by proximity test
229                    if (x0 == x2) {
230                        // Linear interpolation.
231                        p = dx * r3;
232                        p1 = 1.0 - r3;
233                    } else {
234                        // Inverse quadratic interpolation.
235                        double r1 = y0 / y2;
236                        double r2 = y1 / y2;
237                        p = r3 * (dx * r1 * (r1 - r2) - (x1 - x0) * (r2 - 1.0));
238                        p1 = (r1 - 1.0) * (r2 - 1.0) * (r3 - 1.0);
239                    }
240                    if (p > 0.0) {
241                        p1 = -p1;
242                    } else {
243                        p = -p;
244                    }
245                    if (2.0 * p >= 1.5 * dx * p1 - Math.abs(tolerance * p1) ||
246                            p >= Math.abs(0.5 * oldDelta * p1)) {
247                        // Inverse quadratic interpolation gives a value
248                        // in the wrong direction, or progress is slow.
249                        // Fall back to bisection.
250                        delta = 0.5 * dx;
251                        oldDelta = delta;
252                    } else {
253                        oldDelta = delta;
254                        delta = p / p1;
255                    }
256                }
257                // Save old X1, Y1 
258                x0 = x1;
259                y0 = y1;
260                // Compute new X1, Y1
261                if (Math.abs(delta) > tolerance) {
262                    x1 = x1 + delta;
263                } else if (dx > 0.0) {
264                    x1 = x1 + 0.5 * tolerance;
265                } else if (dx <= 0.0) {
266                    x1 = x1 - 0.5 * tolerance;
267                }
268                y1 = f.value(x1);
269                if ((y1 > 0) == (y2 > 0)) {
270                    x2 = x0;
271                    y2 = y0;
272                    delta = x1 - x0;
273                    oldDelta = delta;
274                }
275                i++;
276            }
277            throw new MaxIterationsExceededException(maximalIterationCount);
278        }
279    }