View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.optim.nonlinear.scalar;
18  
19  import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
20  import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
21  import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
22  import org.apache.commons.math4.legacy.optim.OptimizationData;
23  import org.apache.commons.math4.legacy.optim.PointValuePair;
24  
25  /**
26   * Base class for implementing optimizers for multivariate scalar
27   * differentiable functions.
28   * It contains boiler-plate code for dealing with gradient evaluation.
29   *
30   * @since 3.1
31   */
32  public abstract class GradientMultivariateOptimizer
33      extends MultivariateOptimizer {
34      /**
35       * Gradient of the objective function.
36       */
37      private MultivariateVectorFunction gradient;
38  
39      /**
40       * @param checker Convergence checker.
41       */
42      protected GradientMultivariateOptimizer(ConvergenceChecker<PointValuePair> checker) {
43          super(checker);
44      }
45  
46      /**
47       * Compute the gradient vector.
48       *
49       * @param params Point at which the gradient must be evaluated.
50       * @return the gradient at the specified point.
51       */
52      protected double[] computeObjectiveGradient(final double[] params) {
53          return gradient.value(params);
54      }
55  
56      /**
57       * {@inheritDoc}
58       *
59       * @param optData Optimization data. In addition to those documented in
60       * {@link MultivariateOptimizer#parseOptimizationData(OptimizationData[])
61       * MultivariateOptimizer}, this method will register the following data:
62       * <ul>
63       *  <li>{@link ObjectiveFunctionGradient}</li>
64       * </ul>
65       * @return {@inheritDoc}
66       * @throws TooManyEvaluationsException if the maximal number of
67       * evaluations (of the objective function) is exceeded.
68       */
69      @Override
70      public PointValuePair optimize(OptimizationData... optData)
71          throws TooManyEvaluationsException {
72          // Set up base class and perform computation.
73          return super.optimize(optData);
74      }
75  
76      /**
77       * Scans the list of (required and optional) optimization data that
78       * characterize the problem.
79       *
80       * @param optData Optimization data.
81       * The following data will be looked for:
82       * <ul>
83       *  <li>{@link ObjectiveFunctionGradient}</li>
84       * </ul>
85       */
86      @Override
87      protected void parseOptimizationData(OptimizationData... optData) {
88          // Allow base class to register its own data.
89          super.parseOptimizationData(optData);
90  
91          // The existing values (as set by the previous call) are reused if
92          // not provided in the argument list.
93          for (OptimizationData data : optData) {
94              if  (data instanceof ObjectiveFunctionGradient) {
95                  gradient = ((ObjectiveFunctionGradient) data).getObjectiveFunctionGradient();
96                  // If more data must be parsed, this statement _must_ be
97                  // changed to "continue".
98                  break;
99              }
100         }
101     }
102 }