001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.optimization.general;
019    
020    import org.apache.commons.math.exception.MathIllegalStateException;
021    import org.apache.commons.math.analysis.UnivariateRealFunction;
022    import org.apache.commons.math.analysis.solvers.BrentSolver;
023    import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
024    import org.apache.commons.math.exception.util.LocalizedFormats;
025    import org.apache.commons.math.optimization.GoalType;
026    import org.apache.commons.math.optimization.RealPointValuePair;
027    import org.apache.commons.math.optimization.SimpleScalarValueChecker;
028    import org.apache.commons.math.optimization.ConvergenceChecker;
029    import org.apache.commons.math.util.FastMath;
030    
031    /**
032     * Non-linear conjugate gradient optimizer.
033     * <p>
034     * This class supports both the Fletcher-Reeves and the Polak-Ribi&egrave;re
035     * update formulas for the conjugate search directions. It also supports
036     * optional preconditioning.
037     * </p>
038     *
039     * @version $Id: NonLinearConjugateGradientOptimizer.java 1178805 2011-10-04 14:13:46Z luc $
040     * @since 2.0
041     *
042     */
043    public class NonLinearConjugateGradientOptimizer
044        extends AbstractScalarDifferentiableOptimizer {
045        /** Update formula for the beta parameter. */
046        private final ConjugateGradientFormula updateFormula;
047        /** Preconditioner (may be null). */
048        private final Preconditioner preconditioner;
049        /** solver to use in the line search (may be null). */
050        private final UnivariateRealSolver solver;
051        /** Initial step used to bracket the optimum in line search. */
052        private double initialStep;
053        /** Current point. */
054        private double[] point;
055    
056        /**
057         * Constructor with default {@link SimpleScalarValueChecker checker},
058         * {@link BrentSolver line search solver} and
059         * {@link IdentityPreconditioner preconditioner}.
060         *
061         * @param updateFormula formula to use for updating the &beta; parameter,
062         * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
063         * ConjugateGradientFormula#POLAK_RIBIERE}.
064         */
065        public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
066            this(updateFormula,
067                 new SimpleScalarValueChecker());
068        }
069    
070        /**
071         * Constructor with default {@link BrentSolver line search solver} and
072         * {@link IdentityPreconditioner preconditioner}.
073         *
074         * @param updateFormula formula to use for updating the &beta; parameter,
075         * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
076         * ConjugateGradientFormula#POLAK_RIBIERE}.
077         * @param checker Convergence checker.
078         */
079        public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula,
080                                                   ConvergenceChecker<RealPointValuePair> checker) {
081            this(updateFormula,
082                 checker,
083                 new BrentSolver(),
084                 new IdentityPreconditioner());
085        }
086    
087    
088        /**
089         * Constructor with default {@link IdentityPreconditioner preconditioner}.
090         *
091         * @param updateFormula formula to use for updating the &beta; parameter,
092         * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
093         * ConjugateGradientFormula#POLAK_RIBIERE}.
094         * @param checker Convergence checker.
095         * @param lineSearchSolver Solver to use during line search.
096         */
097        public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula,
098                                                   ConvergenceChecker<RealPointValuePair> checker,
099                                                   final UnivariateRealSolver lineSearchSolver) {
100            this(updateFormula,
101                 checker,
102                 lineSearchSolver,
103                 new IdentityPreconditioner());
104        }
105    
106        /**
107         * @param updateFormula formula to use for updating the &beta; parameter,
108         * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
109         * ConjugateGradientFormula#POLAK_RIBIERE}.
110         * @param checker Convergence checker.
111         * @param lineSearchSolver Solver to use during line search.
112         * @param preconditioner Preconditioner.
113         */
114        public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula,
115                                                   ConvergenceChecker<RealPointValuePair> checker,
116                                                   final UnivariateRealSolver lineSearchSolver,
117                                                   final Preconditioner preconditioner) {
118            super(checker);
119    
120            this.updateFormula = updateFormula;
121            solver = lineSearchSolver;
122            this.preconditioner = preconditioner;
123            initialStep = 1.0;
124        }
125    
126        /**
127         * Set the initial step used to bracket the optimum in line search.
128         * <p>
129         * The initial step is a factor with respect to the search direction,
130         * which itself is roughly related to the gradient of the function
131         * </p>
132         * @param initialStep initial step used to bracket the optimum in line search,
133         * if a non-positive value is used, the initial step is reset to its
134         * default value of 1.0
135         */
136        public void setInitialStep(final double initialStep) {
137            if (initialStep <= 0) {
138                this.initialStep = 1.0;
139            } else {
140                this.initialStep = initialStep;
141            }
142        }
143    
144        /** {@inheritDoc} */
145        @Override
146        protected RealPointValuePair doOptimize() {
147            final ConvergenceChecker<RealPointValuePair> checker = getConvergenceChecker();
148            point = getStartPoint();
149            final GoalType goal = getGoalType();
150            final int n = point.length;
151            double[] r = computeObjectiveGradient(point);
152            if (goal == GoalType.MINIMIZE) {
153                for (int i = 0; i < n; ++i) {
154                    r[i] = -r[i];
155                }
156            }
157    
158            // Initial search direction.
159            double[] steepestDescent = preconditioner.precondition(point, r);
160            double[] searchDirection = steepestDescent.clone();
161    
162            double delta = 0;
163            for (int i = 0; i < n; ++i) {
164                delta += r[i] * searchDirection[i];
165            }
166    
167            RealPointValuePair current = null;
168            int iter = 0;
169            int maxEval = getMaxEvaluations();
170            while (true) {
171                ++iter;
172    
173                final double objective = computeObjectiveValue(point);
174                RealPointValuePair previous = current;
175                current = new RealPointValuePair(point, objective);
176                if (previous != null) {
177                    if (checker.converged(iter, previous, current)) {
178                        // We have found an optimum.
179                        return current;
180                    }
181                }
182    
183                // Find the optimal step in the search direction.
184                final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
185                final double uB = findUpperBound(lsf, 0, initialStep);
186                // XXX Last parameters is set to a value close to zero in order to
187                // work around the divergence problem in the "testCircleFitting"
188                // unit test (see MATH-439).
189                final double step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
190                maxEval -= solver.getEvaluations(); // Subtract used up evaluations.
191    
192                // Validate new point.
193                for (int i = 0; i < point.length; ++i) {
194                    point[i] += step * searchDirection[i];
195                }
196    
197                r = computeObjectiveGradient(point);
198                if (goal == GoalType.MINIMIZE) {
199                    for (int i = 0; i < n; ++i) {
200                        r[i] = -r[i];
201                    }
202                }
203    
204                // Compute beta.
205                final double deltaOld = delta;
206                final double[] newSteepestDescent = preconditioner.precondition(point, r);
207                delta = 0;
208                for (int i = 0; i < n; ++i) {
209                    delta += r[i] * newSteepestDescent[i];
210                }
211    
212                final double beta;
213                if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
214                    beta = delta / deltaOld;
215                } else {
216                    double deltaMid = 0;
217                    for (int i = 0; i < r.length; ++i) {
218                        deltaMid += r[i] * steepestDescent[i];
219                    }
220                    beta = (delta - deltaMid) / deltaOld;
221                }
222                steepestDescent = newSteepestDescent;
223    
224                // Compute conjugate search direction.
225                if (iter % n == 0 ||
226                    beta < 0) {
227                    // Break conjugation: reset search direction.
228                    searchDirection = steepestDescent.clone();
229                } else {
230                    // Compute new conjugate search direction.
231                    for (int i = 0; i < n; ++i) {
232                        searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
233                    }
234                }
235            }
236        }
237    
238        /**
239         * Find the upper bound b ensuring bracketing of a root between a and b.
240         *
241         * @param f function whose root must be bracketed.
242         * @param a lower bound of the interval.
243         * @param h initial step to try.
244         * @return b such that f(a) and f(b) have opposite signs.
245         * @throws MathIllegalStateException if no bracket can be found.
246         */
247        private double findUpperBound(final UnivariateRealFunction f,
248                                      final double a, final double h) {
249            final double yA = f.value(a);
250            double yB = yA;
251            for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
252                final double b = a + step;
253                yB = f.value(b);
254                if (yA * yB <= 0) {
255                    return b;
256                }
257            }
258            throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
259        }
260    
261        /** Default identity preconditioner. */
262        public static class IdentityPreconditioner implements Preconditioner {
263    
264            /** {@inheritDoc} */
265            public double[] precondition(double[] variables, double[] r) {
266                return r.clone();
267            }
268        }
269    
270        /** Internal class for line search.
271         * <p>
272         * The function represented by this class is the dot product of
273         * the objective function gradient and the search direction. Its
274         * value is zero when the gradient is orthogonal to the search
275         * direction, i.e. when the objective function value is a local
276         * extremum along the search direction.
277         * </p>
278         */
279        private class LineSearchFunction implements UnivariateRealFunction {
280            /** Search direction. */
281            private final double[] searchDirection;
282    
283            /** Simple constructor.
284             * @param searchDirection search direction
285             */
286            public LineSearchFunction(final double[] searchDirection) {
287                this.searchDirection = searchDirection;
288            }
289    
290            /** {@inheritDoc} */
291            public double value(double x) {
292                // current point in the search direction
293                final double[] shiftedPoint = point.clone();
294                for (int i = 0; i < shiftedPoint.length; ++i) {
295                    shiftedPoint[i] += x * searchDirection[i];
296                }
297    
298                // gradient of the objective function
299                final double[] gradient = computeObjectiveGradient(shiftedPoint);
300    
301                // dot product with the search direction
302                double dotProduct = 0;
303                for (int i = 0; i < gradient.length; ++i) {
304                    dotProduct += gradient[i] * searchDirection[i];
305                }
306    
307                return dotProduct;
308            }
309        }
310    }