NonLinearConjugateGradientOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.legacy.optim.nonlinear.scalar.gradient;

  18. import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
  19. import org.apache.commons.math4.legacy.exception.MathInternalError;
  20. import org.apache.commons.math4.legacy.exception.MathUnsupportedOperationException;
  21. import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
  22. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  23. import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
  24. import org.apache.commons.math4.legacy.optim.OptimizationData;
  25. import org.apache.commons.math4.legacy.optim.PointValuePair;
  26. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType;
  27. import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GradientMultivariateOptimizer;

  28. /**
  29.  * Non-linear conjugate gradient optimizer.
  30.  * <br>
  31.  * This class supports both the Fletcher-Reeves and the Polak-Ribière
  32.  * update formulas for the conjugate search directions.
  33.  * It also supports optional preconditioning.
  34.  * <br>
  35.  * Line search must be setup via {@link org.apache.commons.math4.legacy.optim.nonlinear.scalar.LineSearchTolerance}.
  36.  * <br>
  37.  * Constraints are not supported: the call to
  38.  * {@link #optimize(OptimizationData[]) optimize} will throw
  39.  * {@link MathUnsupportedOperationException} if bounds are passed to it.
  40.  *
  41.  * @since 2.0
  42.  */
  43. public class NonLinearConjugateGradientOptimizer
  44.     extends GradientMultivariateOptimizer {
  45.     /** Update formula for the beta parameter. */
  46.     private final Formula updateFormula;
  47.     /** Preconditioner (may be null). */
  48.     private final Preconditioner preconditioner;

  49.     /**
  50.      * Available choices of update formulas for the updating the parameter
  51.      * that is used to compute the successive conjugate search directions.
  52.      * For non-linear conjugate gradients, there are
  53.      * two formulas:
  54.      * <ul>
  55.      *   <li>Fletcher-Reeves formula</li>
  56.      *   <li>Polak-Ribière formula</li>
  57.      * </ul>
  58.      *
  59.      * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
  60.      * if the start point is close enough of the optimum whether the
  61.      * Polak-Ribière formula may not converge in rare cases. On the
  62.      * other hand, the Polak-Ribière formula is often faster when it
  63.      * does converge. Polak-Ribière is often used.
  64.      *
  65.      * @since 2.0
  66.      */
  67.     public enum Formula {
  68.         /** Fletcher-Reeves formula. */
  69.         FLETCHER_REEVES,
  70.         /** Polak-Ribière formula. */
  71.         POLAK_RIBIERE
  72.     }

  73.     /**
  74.      * Constructor with default {@link IdentityPreconditioner preconditioner}.
  75.      *
  76.      * @param updateFormula formula to use for updating the &beta; parameter,
  77.      * must be one of {@link Formula#FLETCHER_REEVES} or
  78.      * {@link Formula#POLAK_RIBIERE}.
  79.      * @param checker Convergence checker.
  80.      *
  81.      * @since 3.3
  82.      */
  83.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  84.                                                ConvergenceChecker<PointValuePair> checker) {
  85.         this(updateFormula,
  86.              checker,
  87.              new IdentityPreconditioner());
  88.     }

  89.     /**
  90.      * @param updateFormula formula to use for updating the &beta; parameter,
  91.      * must be one of {@link Formula#FLETCHER_REEVES} or
  92.      * {@link Formula#POLAK_RIBIERE}.
  93.      * @param checker Convergence checker.
  94.      * @param preconditioner Preconditioner.
  95.      *
  96.      * @since 3.3
  97.      */
  98.     public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
  99.                                                ConvergenceChecker<PointValuePair> checker,
  100.                                                final Preconditioner preconditioner) {
  101.         super(checker);

  102.         this.updateFormula = updateFormula;
  103.         this.preconditioner = preconditioner;
  104.     }

  105.     /**
  106.      * {@inheritDoc}
  107.      */
  108.     @Override
  109.     public PointValuePair optimize(OptimizationData... optData)
  110.         throws TooManyEvaluationsException {
  111.         // Set up base class and perform computation.
  112.         return super.optimize(optData);
  113.     }

  114.     /** {@inheritDoc} */
  115.     @Override
  116.     protected PointValuePair doOptimize() {
  117.         final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
  118.         final double[] point = getStartPoint();
  119.         final GoalType goal = getGoalType();
  120.         final MultivariateFunction func = getObjectiveFunction();
  121.         final int n = point.length;
  122.         double[] r = computeObjectiveGradient(point);
  123.         if (goal == GoalType.MINIMIZE) {
  124.             for (int i = 0; i < n; i++) {
  125.                 r[i] = -r[i];
  126.             }
  127.         }

  128.         // Initial search direction.
  129.         double[] steepestDescent = preconditioner.precondition(point, r);
  130.         double[] searchDirection = steepestDescent.clone();

  131.         double delta = 0;
  132.         for (int i = 0; i < n; ++i) {
  133.             delta += r[i] * searchDirection[i];
  134.         }

  135.         createLineSearch();

  136.         PointValuePair current = null;
  137.         while (true) {
  138.             incrementIterationCount();

  139.             final double objective = func.value(point);
  140.             PointValuePair previous = current;
  141.             current = new PointValuePair(point, objective);
  142.             if (previous != null &&
  143.                 checker.converged(getIterations(), previous, current)) {
  144.                 // We have found an optimum.
  145.                 return current;
  146.             }

  147.             final double step = lineSearch(point, searchDirection).getPoint();

  148.             // Validate new point.
  149.             for (int i = 0; i < point.length; ++i) {
  150.                 point[i] += step * searchDirection[i];
  151.             }

  152.             r = computeObjectiveGradient(point);
  153.             if (goal == GoalType.MINIMIZE) {
  154.                 for (int i = 0; i < n; ++i) {
  155.                     r[i] = -r[i];
  156.                 }
  157.             }

  158.             // Compute beta.
  159.             final double deltaOld = delta;
  160.             final double[] newSteepestDescent = preconditioner.precondition(point, r);
  161.             delta = 0;
  162.             for (int i = 0; i < n; ++i) {
  163.                 delta += r[i] * newSteepestDescent[i];
  164.             }

  165.             final double beta;
  166.             switch (updateFormula) {
  167.             case FLETCHER_REEVES:
  168.                 beta = delta / deltaOld;
  169.                 break;
  170.             case POLAK_RIBIERE:
  171.                 double deltaMid = 0;
  172.                 for (int i = 0; i < r.length; ++i) {
  173.                     deltaMid += r[i] * steepestDescent[i];
  174.                 }
  175.                 beta = (delta - deltaMid) / deltaOld;
  176.                 break;
  177.             default:
  178.                 // Should never happen.
  179.                 throw new MathInternalError();
  180.             }
  181.             steepestDescent = newSteepestDescent;

  182.             // Compute conjugate search direction.
  183.             if (getIterations() % n == 0 ||
  184.                 beta < 0) {
  185.                 // Break conjugation: reset search direction.
  186.                 searchDirection = steepestDescent.clone();
  187.             } else {
  188.                 // Compute new conjugate search direction.
  189.                 for (int i = 0; i < n; ++i) {
  190.                     searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
  191.                 }
  192.             }
  193.         }
  194.     }

  195.     /**
  196.      * {@inheritDoc}
  197.      */
  198.     @Override
  199.     protected void parseOptimizationData(OptimizationData... optData) {
  200.         // Allow base class to register its own data.
  201.         super.parseOptimizationData(optData);

  202.         checkParameters();
  203.     }

  204.     /** Default identity preconditioner. */
  205.     public static class IdentityPreconditioner implements Preconditioner {
  206.         /** {@inheritDoc} */
  207.         @Override
  208.         public double[] precondition(double[] variables, double[] r) {
  209.             return r.clone();
  210.         }
  211.     }

  212.     // Class is not used anymore (cf. MATH-1092). However, it might
  213.     // be interesting to create a class similar to "LineSearch", but
  214.     // that will take advantage that the model's gradient is available.
  215. //     /**
  216. //      * Internal class for line search.
  217. //      * <p>
  218. //      * The function represented by this class is the dot product of
  219. //      * the objective function gradient and the search direction. Its
  220. //      * value is zero when the gradient is orthogonal to the search
  221. //      * direction, i.e. when the objective function value is a local
  222. //      * extremum along the search direction.
  223. //      * </p>
  224. //      */
  225. //     private class LineSearchFunction implements UnivariateFunction {
  226. //         /** Current point. */
  227. //         private final double[] currentPoint;
  228. //         /** Search direction. */
  229. //         private final double[] searchDirection;

  230. //         /**
  231. //          * @param point Current point.
  232. //          * @param direction Search direction.
  233. //          */
  234. //         public LineSearchFunction(double[] point,
  235. //                                   double[] direction) {
  236. //             currentPoint = point.clone();
  237. //             searchDirection = direction.clone();
  238. //         }

  239. //         /** {@inheritDoc} */
  240. //         public double value(double x) {
  241. //             // current point in the search direction
  242. //             final double[] shiftedPoint = currentPoint.clone();
  243. //             for (int i = 0; i < shiftedPoint.length; ++i) {
  244. //                 shiftedPoint[i] += x * searchDirection[i];
  245. //             }

  246. //             // gradient of the objective function
  247. //             final double[] gradient = computeObjectiveGradient(shiftedPoint);

  248. //             // dot product with the search direction
  249. //             double dotProduct = 0;
  250. //             for (int i = 0; i < gradient.length; ++i) {
  251. //                 dotProduct += gradient[i] * searchDirection[i];
  252. //             }

  253. //             return dotProduct;
  254. //         }
  255. //     }

  256.     /**
  257.      * @throws MathUnsupportedOperationException if bounds were passed to the
  258.      * {@link #optimize(OptimizationData[]) optimize} method.
  259.      */
  260.     private void checkParameters() {
  261.         if (getLowerBound() != null ||
  262.             getUpperBound() != null) {
  263.             throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT);
  264.         }
  265.     }
  266. }