001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
019    
020    import org.apache.commons.math3.analysis.UnivariateFunction;
021    import org.apache.commons.math3.analysis.solvers.BrentSolver;
022    import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
023    import org.apache.commons.math3.exception.MathInternalError;
024    import org.apache.commons.math3.exception.MathIllegalStateException;
025    import org.apache.commons.math3.exception.TooManyEvaluationsException;
026    import org.apache.commons.math3.exception.util.LocalizedFormats;
027    import org.apache.commons.math3.optim.OptimizationData;
028    import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
029    import org.apache.commons.math3.optim.PointValuePair;
030    import org.apache.commons.math3.optim.ConvergenceChecker;
031    import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
032    import org.apache.commons.math3.util.FastMath;
033    
034    /**
035     * Non-linear conjugate gradient optimizer.
036     * <p>
037     * This class supports both the Fletcher-Reeves and the Polak-Ribière
038     * update formulas for the conjugate search directions.
039     * It also supports optional preconditioning.
040     * </p>
041     *
042     * @version $Id: NonLinearConjugateGradientOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
043     * @since 2.0
044     */
045    public class NonLinearConjugateGradientOptimizer
046        extends GradientMultivariateOptimizer {
047        /** Update formula for the beta parameter. */
048        private final Formula updateFormula;
049        /** Preconditioner (may be null). */
050        private final Preconditioner preconditioner;
051        /** solver to use in the line search (may be null). */
052        private final UnivariateSolver solver;
053        /** Initial step used to bracket the optimum in line search. */
054        private double initialStep = 1;
055    
056        /**
057         * Constructor with default {@link BrentSolver line search solver} and
058         * {@link IdentityPreconditioner preconditioner}.
059         *
060         * @param updateFormula formula to use for updating the &beta; parameter,
061         * must be one of {@link Formula#FLETCHER_REEVES} or
062         * {@link Formula#POLAK_RIBIERE}.
063         * @param checker Convergence checker.
064         */
065        public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
066                                                   ConvergenceChecker<PointValuePair> checker) {
067            this(updateFormula,
068                 checker,
069                 new BrentSolver(),
070                 new IdentityPreconditioner());
071        }
072    
073        /**
074         * Available choices of update formulas for the updating the parameter
075         * that is used to compute the successive conjugate search directions.
076         * For non-linear conjugate gradients, there are
077         * two formulas:
078         * <ul>
079         *   <li>Fletcher-Reeves formula</li>
080         *   <li>Polak-Ribière formula</li>
081         * </ul>
082         *
083         * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
084         * if the start point is close enough of the optimum whether the
085         * Polak-Ribière formula may not converge in rare cases. On the
086         * other hand, the Polak-Ribière formula is often faster when it
087         * does converge. Polak-Ribière is often used.
088         *
089         * @since 2.0
090         */
091        public static enum Formula {
092            /** Fletcher-Reeves formula. */
093            FLETCHER_REEVES,
094            /** Polak-Ribière formula. */
095            POLAK_RIBIERE
096        }
097    
098        /**
099         * The initial step is a factor with respect to the search direction
100         * (which itself is roughly related to the gradient of the function).
101         * <br/>
102         * It is used to find an interval that brackets the optimum in line
103         * search.
104         *
105         * @since 3.1
106         */
107        public static class BracketingStep implements OptimizationData {
108            /** Initial step. */
109            private final double initialStep;
110    
111            /**
112             * @param step Initial step for the bracket search.
113             */
114            public BracketingStep(double step) {
115                initialStep = step;
116            }
117    
118            /**
119             * Gets the initial step.
120             *
121             * @return the initial step.
122             */
123            public double getBracketingStep() {
124                return initialStep;
125            }
126        }
127    
128        /**
129         * Constructor with default {@link IdentityPreconditioner preconditioner}.
130         *
131         * @param updateFormula formula to use for updating the &beta; parameter,
132         * must be one of {@link Formula#FLETCHER_REEVES} or
133         * {@link Formula#POLAK_RIBIERE}.
134         * @param checker Convergence checker.
135         * @param lineSearchSolver Solver to use during line search.
136         */
137        public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
138                                                   ConvergenceChecker<PointValuePair> checker,
139                                                   final UnivariateSolver lineSearchSolver) {
140            this(updateFormula,
141                 checker,
142                 lineSearchSolver,
143                 new IdentityPreconditioner());
144        }
145    
146        /**
147         * @param updateFormula formula to use for updating the &beta; parameter,
148         * must be one of {@link Formula#FLETCHER_REEVES} or
149         * {@link Formula#POLAK_RIBIERE}.
150         * @param checker Convergence checker.
151         * @param lineSearchSolver Solver to use during line search.
152         * @param preconditioner Preconditioner.
153         */
154        public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
155                                                   ConvergenceChecker<PointValuePair> checker,
156                                                   final UnivariateSolver lineSearchSolver,
157                                                   final Preconditioner preconditioner) {
158            super(checker);
159    
160            this.updateFormula = updateFormula;
161            solver = lineSearchSolver;
162            this.preconditioner = preconditioner;
163            initialStep = 1;
164        }
165    
166        /**
167         * {@inheritDoc}
168         *
169         * @param optData Optimization data.
170         * The following data will be looked for:
171         * <ul>
172         *  <li>{@link org.apache.commons.math3.optim.MaxEval}</li>
173         *  <li>{@link org.apache.commons.math3.optim.InitialGuess}</li>
174         *  <li>{@link org.apache.commons.math3.optim.SimpleBounds}</li>
175         *  <li>{@link org.apache.commons.math3.optim.nonlinear.scalar.GoalType}</li>
176         *  <li>{@link org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction}</li>
177         *  <li>{@link org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient}</li>
178         *  <li>{@link BracketingStep}</li>
179         * </ul>
180         * @return {@inheritDoc}
181         * @throws TooManyEvaluationsException if the maximal number of
182         * evaluations (of the objective function) is exceeded.
183         */
184        @Override
185        public PointValuePair optimize(OptimizationData... optData)
186            throws TooManyEvaluationsException {
187             // Retrieve settings.
188            parseOptimizationData(optData);
189            // Set up base class and perform computation.
190            return super.optimize(optData);
191        }
192    
193        /** {@inheritDoc} */
194        @Override
195        protected PointValuePair doOptimize() {
196            final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
197            final double[] point = getStartPoint();
198            final GoalType goal = getGoalType();
199            final int n = point.length;
200            double[] r = computeObjectiveGradient(point);
201            if (goal == GoalType.MINIMIZE) {
202                for (int i = 0; i < n; i++) {
203                    r[i] = -r[i];
204                }
205            }
206    
207            // Initial search direction.
208            double[] steepestDescent = preconditioner.precondition(point, r);
209            double[] searchDirection = steepestDescent.clone();
210    
211            double delta = 0;
212            for (int i = 0; i < n; ++i) {
213                delta += r[i] * searchDirection[i];
214            }
215    
216            PointValuePair current = null;
217            int iter = 0;
218            int maxEval = getMaxEvaluations();
219            while (true) {
220                ++iter;
221    
222                final double objective = computeObjectiveValue(point);
223                PointValuePair previous = current;
224                current = new PointValuePair(point, objective);
225                if (previous != null) {
226                    if (checker.converged(iter, previous, current)) {
227                        // We have found an optimum.
228                        return current;
229                    }
230                }
231    
232                // Find the optimal step in the search direction.
233                final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
234                final double uB = findUpperBound(lsf, 0, initialStep);
235                // XXX Last parameters is set to a value close to zero in order to
236                // work around the divergence problem in the "testCircleFitting"
237                // unit test (see MATH-439).
238                final double step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
239                maxEval -= solver.getEvaluations(); // Subtract used up evaluations.
240    
241                // Validate new point.
242                for (int i = 0; i < point.length; ++i) {
243                    point[i] += step * searchDirection[i];
244                }
245    
246                r = computeObjectiveGradient(point);
247                if (goal == GoalType.MINIMIZE) {
248                    for (int i = 0; i < n; ++i) {
249                        r[i] = -r[i];
250                    }
251                }
252    
253                // Compute beta.
254                final double deltaOld = delta;
255                final double[] newSteepestDescent = preconditioner.precondition(point, r);
256                delta = 0;
257                for (int i = 0; i < n; ++i) {
258                    delta += r[i] * newSteepestDescent[i];
259                }
260    
261                final double beta;
262                switch (updateFormula) {
263                case FLETCHER_REEVES:
264                    beta = delta / deltaOld;
265                    break;
266                case POLAK_RIBIERE:
267                    double deltaMid = 0;
268                    for (int i = 0; i < r.length; ++i) {
269                        deltaMid += r[i] * steepestDescent[i];
270                    }
271                    beta = (delta - deltaMid) / deltaOld;
272                    break;
273                default:
274                    // Should never happen.
275                    throw new MathInternalError();
276                }
277                steepestDescent = newSteepestDescent;
278    
279                // Compute conjugate search direction.
280                if (iter % n == 0 ||
281                    beta < 0) {
282                    // Break conjugation: reset search direction.
283                    searchDirection = steepestDescent.clone();
284                } else {
285                    // Compute new conjugate search direction.
286                    for (int i = 0; i < n; ++i) {
287                        searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
288                    }
289                }
290            }
291        }
292    
293        /**
294         * Scans the list of (required and optional) optimization data that
295         * characterize the problem.
296         *
297         * @param optData Optimization data.
298         * The following data will be looked for:
299         * <ul>
300         *  <li>{@link InitialStep}</li>
301         * </ul>
302         */
303        private void parseOptimizationData(OptimizationData... optData) {
304            // The existing values (as set by the previous call) are reused if
305            // not provided in the argument list.
306            for (OptimizationData data : optData) {
307                if  (data instanceof BracketingStep) {
308                    initialStep = ((BracketingStep) data).getBracketingStep();
309                    // If more data must be parsed, this statement _must_ be
310                    // changed to "continue".
311                    break;
312                }
313            }
314        }
315    
316        /**
317         * Finds the upper bound b ensuring bracketing of a root between a and b.
318         *
319         * @param f function whose root must be bracketed.
320         * @param a lower bound of the interval.
321         * @param h initial step to try.
322         * @return b such that f(a) and f(b) have opposite signs.
323         * @throws MathIllegalStateException if no bracket can be found.
324         */
325        private double findUpperBound(final UnivariateFunction f,
326                                      final double a, final double h) {
327            final double yA = f.value(a);
328            double yB = yA;
329            for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
330                final double b = a + step;
331                yB = f.value(b);
332                if (yA * yB <= 0) {
333                    return b;
334                }
335            }
336            throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
337        }
338    
339        /** Default identity preconditioner. */
340        public static class IdentityPreconditioner implements Preconditioner {
341            /** {@inheritDoc} */
342            public double[] precondition(double[] variables, double[] r) {
343                return r.clone();
344            }
345        }
346    
347        /**
348         * Internal class for line search.
349         * <p>
350         * The function represented by this class is the dot product of
351         * the objective function gradient and the search direction. Its
352         * value is zero when the gradient is orthogonal to the search
353         * direction, i.e. when the objective function value is a local
354         * extremum along the search direction.
355         * </p>
356         */
357        private class LineSearchFunction implements UnivariateFunction {
358            /** Current point. */
359            private final double[] currentPoint;
360            /** Search direction. */
361            private final double[] searchDirection;
362    
363            /**
364             * @param point Current point.
365             * @param direction Search direction.
366             */
367            public LineSearchFunction(double[] point,
368                                      double[] direction) {
369                currentPoint = point.clone();
370                searchDirection = direction.clone();
371            }
372    
373            /** {@inheritDoc} */
374            public double value(double x) {
375                // current point in the search direction
376                final double[] shiftedPoint = currentPoint.clone();
377                for (int i = 0; i < shiftedPoint.length; ++i) {
378                    shiftedPoint[i] += x * searchDirection[i];
379                }
380    
381                // gradient of the objective function
382                final double[] gradient = computeObjectiveGradient(shiftedPoint);
383    
384                // dot product with the search direction
385                double dotProduct = 0;
386                for (int i = 0; i < gradient.length; ++i) {
387                    dotProduct += gradient[i] * searchDirection[i];
388                }
389    
390                return dotProduct;
391            }
392        }
393    }