View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.optim.nonlinear.scalar.gradient;
19  
20  import org.apache.commons.math3.analysis.UnivariateFunction;
21  import org.apache.commons.math3.analysis.solvers.BrentSolver;
22  import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
23  import org.apache.commons.math3.exception.MathInternalError;
24  import org.apache.commons.math3.exception.MathIllegalStateException;
25  import org.apache.commons.math3.exception.TooManyEvaluationsException;
26  import org.apache.commons.math3.exception.MathUnsupportedOperationException;
27  import org.apache.commons.math3.exception.util.LocalizedFormats;
28  import org.apache.commons.math3.optim.OptimizationData;
29  import org.apache.commons.math3.optim.PointValuePair;
30  import org.apache.commons.math3.optim.ConvergenceChecker;
31  import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
32  import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
33  import org.apache.commons.math3.util.FastMath;
34  
35  /**
36   * Non-linear conjugate gradient optimizer.
37   * <br/>
38   * This class supports both the Fletcher-Reeves and the Polak-Ribière
39   * update formulas for the conjugate search directions.
40   * It also supports optional preconditioning.
41   * <br/>
42   * Constraints are not supported: the call to
43   * {@link #optimize(OptimizationData[]) optimize} will throw
44   * {@link MathUnsupportedOperationException} if bounds are passed to it.
45   *
46   * @version $Id: NonLinearConjugateGradientOptimizer.java 1462503 2013-03-29 15:48:27Z luc $
47   * @since 2.0
48   */
public class NonLinearConjugateGradientOptimizer
    extends GradientMultivariateOptimizer {
    /** Formula used to update the &beta; parameter that builds the successive conjugate directions. */
    private final Formula updateFormula;
    /**
     * Preconditioner applied to the steepest-descent direction.
     * The provided constructors always supply a non-null value (default:
     * {@link IdentityPreconditioner}), but a null argument is not rejected.
     */
    private final Preconditioner preconditioner;
    /**
     * Solver used to locate the optimum along the search direction.
     * The provided constructors always supply a non-null value (default:
     * {@link BrentSolver}), but a null argument is not rejected.
     */
    private final UnivariateSolver solver;
    /**
     * Initial step used to bracket the optimum in line search.
     * Mutable: overwritten when a {@link BracketingStep} instance is passed
     * among the optimization data (see {@code parseOptimizationData}).
     */
    private double initialStep = 1;
    /**
     * Constructor with default {@link BrentSolver line search solver} and
     * {@link IdentityPreconditioner preconditioner}.
     * Simply delegates to the four-argument constructor.
     *
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker) {
        this(updateFormula,
             checker,
             new BrentSolver(),
             new IdentityPreconditioner());
    }
76  
77      /**
78       * Available choices of update formulas for the updating the parameter
79       * that is used to compute the successive conjugate search directions.
80       * For non-linear conjugate gradients, there are
81       * two formulas:
82       * <ul>
83       *   <li>Fletcher-Reeves formula</li>
84       *   <li>Polak-Ribière formula</li>
85       * </ul>
86       *
87       * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
88       * if the start point is close enough of the optimum whether the
89       * Polak-Ribière formula may not converge in rare cases. On the
90       * other hand, the Polak-Ribière formula is often faster when it
91       * does converge. Polak-Ribière is often used.
92       *
93       * @since 2.0
94       */
95      public static enum Formula {
96          /** Fletcher-Reeves formula. */
97          FLETCHER_REEVES,
98          /** Polak-Ribière formula. */
99          POLAK_RIBIERE
100     }
101 
102     /**
103      * The initial step is a factor with respect to the search direction
104      * (which itself is roughly related to the gradient of the function).
105      * <br/>
106      * It is used to find an interval that brackets the optimum in line
107      * search.
108      *
109      * @since 3.1
110      */
111     public static class BracketingStep implements OptimizationData {
112         /** Initial step. */
113         private final double initialStep;
114 
115         /**
116          * @param step Initial step for the bracket search.
117          */
118         public BracketingStep(double step) {
119             initialStep = step;
120         }
121 
122         /**
123          * Gets the initial step.
124          *
125          * @return the initial step.
126          */
127         public double getBracketingStep() {
128             return initialStep;
129         }
130     }
131 
    /**
     * Constructor with default {@link IdentityPreconditioner preconditioner}.
     * Simply delegates to the four-argument constructor.
     *
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     * @param lineSearchSolver Solver to use during line search.
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker,
                                               final UnivariateSolver lineSearchSolver) {
        this(updateFormula,
             checker,
             lineSearchSolver,
             new IdentityPreconditioner());
    }
149 
150     /**
151      * @param updateFormula formula to use for updating the &beta; parameter,
152      * must be one of {@link Formula#FLETCHER_REEVES} or
153      * {@link Formula#POLAK_RIBIERE}.
154      * @param checker Convergence checker.
155      * @param lineSearchSolver Solver to use during line search.
156      * @param preconditioner Preconditioner.
157      */
158     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
159                                                ConvergenceChecker<PointValuePair> checker,
160                                                final UnivariateSolver lineSearchSolver,
161                                                final Preconditioner preconditioner) {
162         super(checker);
163 
164         this.updateFormula = updateFormula;
165         solver = lineSearchSolver;
166         this.preconditioner = preconditioner;
167         initialStep = 1;
168     }
169 
    /**
     * {@inheritDoc}
     *
     * @param optData Optimization data. In addition to those documented in
     * {@link GradientMultivariateOptimizer#parseOptimizationData(OptimizationData[])
     * GradientMultivariateOptimizer}, this method will register the following data:
     * <ul>
     *  <li>{@link BracketingStep}</li>
     * </ul>
     * @return {@inheritDoc}
     * @throws TooManyEvaluationsException if the maximal number of
     * evaluations (of the objective function) is exceeded.
     */
    @Override
    public PointValuePair optimize(OptimizationData... optData)
        throws TooManyEvaluationsException {
        // Pure delegation: the base class parses the data (which triggers
        // this class's parseOptimizationData override) and runs doOptimize().
        return super.optimize(optData);
    }
189 
    /**
     * {@inheritDoc}
     *
     * Implements the preconditioned non-linear conjugate gradient method:
     * at each iteration a line search is performed along the current
     * conjugate direction, then the direction is updated with the
     * configured &beta; formula (Fletcher-Reeves or Polak-Ribière).
     */
    @Override
    protected PointValuePair doOptimize() {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        final double[] point = getStartPoint();
        final GoalType goal = getGoalType();
        final int n = point.length;
        // "r" is the gradient when maximizing, and the negated gradient
        // (i.e. the steepest-descent direction) when minimizing, so the
        // rest of the loop is goal-agnostic.
        double[] r = computeObjectiveGradient(point);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; i++) {
                r[i] = -r[i];
            }
        }

        // Initial search direction.
        double[] steepestDescent = preconditioner.precondition(point, r);
        double[] searchDirection = steepestDescent.clone();

        // delta = r . (preconditioned r), used as denominator in the
        // beta update formulas below.
        double delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += r[i] * searchDirection[i];
        }

        PointValuePair current = null;
        int maxEval = getMaxEvaluations();
        while (true) {
            incrementIterationCount();

            final double objective = computeObjectiveValue(point);
            PointValuePair previous = current;
            current = new PointValuePair(point, objective);
            if (previous != null && checker.converged(getIterations(), previous, current)) {
                // We have found an optimum.
                return current;
            }

            // Find the optimal step in the search direction: the line-search
            // function is the directional derivative, whose root is a local
            // extremum along the search direction.
            final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
            final double uB = findUpperBound(lsf, 0, initialStep);
            // XXX Last parameter is set to a value close to zero in order to
            // work around the divergence problem in the "testCircleFitting"
            // unit test (see MATH-439).
            final double step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
            maxEval -= solver.getEvaluations(); // Subtract used up evaluations.

            // Move to the new point (in place: "current" above already holds
            // a reference to this array, paired with the old objective).
            for (int i = 0; i < point.length; ++i) {
                point[i] += step * searchDirection[i];
            }

            // Recompute the (sign-adjusted) gradient at the new point.
            r = computeObjectiveGradient(point);
            if (goal == GoalType.MINIMIZE) {
                for (int i = 0; i < n; ++i) {
                    r[i] = -r[i];
                }
            }

            // Compute beta.
            final double deltaOld = delta;
            final double[] newSteepestDescent = preconditioner.precondition(point, r);
            delta = 0;
            for (int i = 0; i < n; ++i) {
                delta += r[i] * newSteepestDescent[i];
            }

            final double beta;
            switch (updateFormula) {
            case FLETCHER_REEVES:
                // beta = (r_new . d_new) / (r_old . d_old)
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                // beta = (r_new . d_new - r_new . d_old) / (r_old . d_old)
                double deltaMid = 0;
                for (int i = 0; i < r.length; ++i) {
                    deltaMid += r[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen.
                throw new MathInternalError();
            }
            steepestDescent = newSteepestDescent;

            // Compute conjugate search direction. Conjugation is broken
            // (restart from steepest descent) every n iterations, or when
            // beta is negative (which would yield a non-descent direction).
            if (getIterations() % n == 0 ||
                beta < 0) {
                // Break conjugation: reset search direction.
                searchDirection = steepestDescent.clone();
            } else {
                // Compute new conjugate search direction.
                for (int i = 0; i < n; ++i) {
                    searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
                }
            }
        }
    }
286 
287     /**
288      * Scans the list of (required and optional) optimization data that
289      * characterize the problem.
290      *
291      * @param optData Optimization data.
292      * The following data will be looked for:
293      * <ul>
294      *  <li>{@link BracketingStep}</li>
295      * </ul>
296      */
297     @Override
298     protected void parseOptimizationData(OptimizationData... optData) {
299         // Allow base class to register its own data.
300         super.parseOptimizationData(optData);
301 
302         // The existing values (as set by the previous call) are reused if
303         // not provided in the argument list.
304         for (OptimizationData data : optData) {
305             if  (data instanceof BracketingStep) {
306                 initialStep = ((BracketingStep) data).getBracketingStep();
307                 // If more data must be parsed, this statement _must_ be
308                 // changed to "continue".
309                 break;
310             }
311         }
312 
313         checkParameters();
314     }
315 
316     /**
317      * Finds the upper bound b ensuring bracketing of a root between a and b.
318      *
319      * @param f function whose root must be bracketed.
320      * @param a lower bound of the interval.
321      * @param h initial step to try.
322      * @return b such that f(a) and f(b) have opposite signs.
323      * @throws MathIllegalStateException if no bracket can be found.
324      */
325     private double findUpperBound(final UnivariateFunction f,
326                                   final double a, final double h) {
327         final double yA = f.value(a);
328         double yB = yA;
329         for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
330             final double b = a + step;
331             yB = f.value(b);
332             if (yA * yB <= 0) {
333                 return b;
334             }
335         }
336         throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
337     }
338 
    /**
     * Default identity preconditioner: the steepest-descent direction is
     * used unchanged. A copy of {@code r} is returned, so the caller may
     * modify the result without affecting the input array.
     */
    public static class IdentityPreconditioner implements Preconditioner {
        /** {@inheritDoc} */
        public double[] precondition(double[] variables, double[] r) {
            // "variables" is intentionally ignored: identity preconditioning
            // does not depend on the current point.
            return r.clone();
        }
    }
346 
347     /**
348      * Internal class for line search.
349      * <p>
350      * The function represented by this class is the dot product of
351      * the objective function gradient and the search direction. Its
352      * value is zero when the gradient is orthogonal to the search
353      * direction, i.e. when the objective function value is a local
354      * extremum along the search direction.
355      * </p>
356      */
357     private class LineSearchFunction implements UnivariateFunction {
358         /** Current point. */
359         private final double[] currentPoint;
360         /** Search direction. */
361         private final double[] searchDirection;
362 
363         /**
364          * @param point Current point.
365          * @param direction Search direction.
366          */
367         public LineSearchFunction(double[] point,
368                                   double[] direction) {
369             currentPoint = point.clone();
370             searchDirection = direction.clone();
371         }
372 
373         /** {@inheritDoc} */
374         public double value(double x) {
375             // current point in the search direction
376             final double[] shiftedPoint = currentPoint.clone();
377             for (int i = 0; i < shiftedPoint.length; ++i) {
378                 shiftedPoint[i] += x * searchDirection[i];
379             }
380 
381             // gradient of the objective function
382             final double[] gradient = computeObjectiveGradient(shiftedPoint);
383 
384             // dot product with the search direction
385             double dotProduct = 0;
386             for (int i = 0; i < gradient.length; ++i) {
387                 dotProduct += gradient[i] * searchDirection[i];
388             }
389 
390             return dotProduct;
391         }
392     }
393 
394     /**
395      * @throws MathUnsupportedOperationException if bounds were passed to the
396      * {@link #optimize(OptimizationData[]) optimize} method.
397      */
398     private void checkParameters() {
399         if (getLowerBound() != null ||
400             getUpperBound() != null) {
401             throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
402         }
403     }
404 }