MultivariateOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.optim.nonlinear.scalar;

  18. import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
  19. import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
  20. import org.apache.commons.math4.legacy.optim.BaseMultivariateOptimizer;
  21. import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
  22. import org.apache.commons.math4.legacy.optim.OptimizationData;
  23. import org.apache.commons.math4.legacy.optim.PointValuePair;
  24. import org.apache.commons.math4.legacy.optim.MaxEval;
  25. import org.apache.commons.math4.legacy.optim.univariate.BracketFinder;
  26. import org.apache.commons.math4.legacy.optim.univariate.BrentOptimizer;
  27. import org.apache.commons.math4.legacy.optim.univariate.SearchInterval;
  28. import org.apache.commons.math4.legacy.optim.univariate.SimpleUnivariateValueChecker;
  29. import org.apache.commons.math4.legacy.optim.univariate.UnivariateObjectiveFunction;
  30. import org.apache.commons.math4.legacy.optim.univariate.UnivariateOptimizer;
  31. import org.apache.commons.math4.legacy.optim.univariate.UnivariatePointValuePair;

  32. /**
  33.  * Base class for a multivariate scalar function optimizer.
  34.  *
  35.  * @since 3.1
  36.  */
  37. public abstract class MultivariateOptimizer
  38.     extends BaseMultivariateOptimizer<PointValuePair> {
  39.     /** Objective function. */
  40.     private MultivariateFunction function;
  41.     /** Type of optimization. */
  42.     private GoalType goal;
  43.     /** Line search relative tolerance. */
  44.     private double lineSearchRelativeTolerance = 1e-8;
  45.     /** Line search absolute tolerance. */
  46.     private double lineSearchAbsoluteTolerance = 1e-8;
  47.     /** Line serach initial bracketing range. */
  48.     private double lineSearchInitialBracketingRange = 1d;
  49.     /** Line search. */
  50.     private LineSearch lineSearch;

  51.     /**
  52.      * @param checker Convergence checker.
  53.      */
  54.     protected MultivariateOptimizer(ConvergenceChecker<PointValuePair> checker) {
  55.         super(checker);
  56.     }

  57.     /**
  58.      * {@inheritDoc}
  59.      *
  60.      * @param optData Optimization data. In addition to those documented in
  61.      * {@link BaseMultivariateOptimizer#parseOptimizationData(OptimizationData[])
  62.      * BaseMultivariateOptimizer}, this method will register the following data:
  63.      * <ul>
  64.      *  <li>{@link ObjectiveFunction}</li>
  65.      *  <li>{@link GoalType}</li>
  66.      *  <li>{@link LineSearchTolerance}</li>
  67.      * </ul>
  68.      * @return {@inheritDoc}
  69.      * @throws org.apache.commons.math4.legacy.exception.TooManyEvaluationsException
  70.      * if the maximal number of evaluations is exceeded.
  71.      */
  72.     @Override
  73.     public PointValuePair optimize(OptimizationData... optData) {
  74.         // Set up base class and perform computation.
  75.         return super.optimize(optData);
  76.     }

  77.     /**
  78.      * Scans the list of (required and optional) optimization data that
  79.      * characterize the problem.
  80.      *
  81.      * @param optData Optimization data.
  82.      * The following data will be looked for:
  83.      * <ul>
  84.      *  <li>{@link ObjectiveFunction}</li>
  85.      *  <li>{@link GoalType}</li>
  86.      *  <li>{@link LineSearchTolerance}</li>
  87.      * </ul>
  88.      */
  89.     @Override
  90.     protected void parseOptimizationData(OptimizationData... optData) {
  91.         // Allow base class to register its own data.
  92.         super.parseOptimizationData(optData);

  93.         // The existing values (as set by the previous call) are reused if
  94.         // not provided in the argument list.
  95.         for (OptimizationData data : optData) {
  96.             if (data instanceof GoalType) {
  97.                 goal = (GoalType) data;
  98.                 continue;
  99.             }
  100.             if (data instanceof ObjectiveFunction) {
  101.                 final MultivariateFunction delegate = ((ObjectiveFunction) data).getObjectiveFunction();
  102.                 function = new MultivariateFunction() {
  103.                         @Override
  104.                         public double value(double[] point) {
  105.                             incrementEvaluationCount();
  106.                             return delegate.value(point);
  107.                         }
  108.                     };
  109.                 continue;
  110.             }
  111.             if (data instanceof LineSearchTolerance) {
  112.                 final LineSearchTolerance tol = (LineSearchTolerance) data;
  113.                 lineSearchRelativeTolerance = tol.getRelativeTolerance();
  114.                 lineSearchAbsoluteTolerance = tol.getAbsoluteTolerance();
  115.                 lineSearchInitialBracketingRange = tol.getInitialBracketingRange();
  116.                 continue;
  117.             }
  118.         }
  119.     }

  120.     /**
  121.      * Intantiate the line search implementation.
  122.      */
  123.     protected void createLineSearch() {
  124.         lineSearch = new LineSearch(this,
  125.                                     lineSearchRelativeTolerance,
  126.                                     lineSearchAbsoluteTolerance,
  127.                                     lineSearchInitialBracketingRange);
  128.     }

  129.     /**
  130.      * Finds the number {@code alpha} that optimizes
  131.      * {@code f(startPoint + alpha * direction)}.
  132.      *
  133.      * @param startPoint Starting point.
  134.      * @param direction Search direction.
  135.      * @return the optimum.
  136.      * @throws org.apache.commons.math4.legacy.exception.TooManyEvaluationsException
  137.      * if the number of evaluations is exceeded.
  138.      */
  139.     protected UnivariatePointValuePair lineSearch(final double[] startPoint,
  140.                                                   final double[] direction) {
  141.         return lineSearch.search(startPoint, direction);
  142.     }

  143.     /**
  144.      * @return the optimization type.
  145.      */
  146.     protected GoalType getGoalType() {
  147.         return goal;
  148.     }

  149.     /**
  150.      * @return a wrapper that delegates to the user-supplied function,
  151.      * and counts the number of evaluations.
  152.      */
  153.     protected MultivariateFunction getObjectiveFunction() {
  154.         return function;
  155.     }

  156.     /**
  157.      * Computes the objective function value.
  158.      * This method <em>must</em> be called by subclasses to enforce the
  159.      * evaluation counter limit.
  160.      *
  161.      * @param params Point at which the objective function must be evaluated.
  162.      * @return the objective function value at the specified point.
  163.      * @throws org.apache.commons.math4.legacy.exception.TooManyEvaluationsException
  164.      * if the maximal number of evaluations is exceeded.
  165.      *
  166.      * @deprecated Use {@link #getObjectiveFunction()} instead.
  167.      */
  168.     @Deprecated
  169.     public double computeObjectiveValue(double[] params) {
  170.         return function.value(params);
  171.     }

  172.     /**
  173.      * Find the minimum of the objective function along a given direction.
  174.      *
  175.      * @since 4.0
  176.      */
  177.     private static class LineSearch {
  178.         /**
  179.          * Value that will pass the precondition check for {@link BrentOptimizer}
  180.          * but will not pass the convergence check, so that the custom checker
  181.          * will always decide when to stop the line search.
  182.          */
  183.         private static final double REL_TOL_UNUSED = 1e-15;
  184.         /**
  185.          * Value that will pass the precondition check for {@link BrentOptimizer}
  186.          * but will not pass the convergence check, so that the custom checker
  187.          * will always decide when to stop the line search.
  188.          */
  189.         private static final double ABS_TOL_UNUSED = Double.MIN_VALUE;
  190.         /**
  191.          * Optimizer used for line search.
  192.          */
  193.         private final UnivariateOptimizer lineOptimizer;
  194.         /**
  195.          * Automatic bracketing.
  196.          */
  197.         private final BracketFinder bracket = new BracketFinder();
  198.         /**
  199.          * Extent of the initial interval used to find an interval that
  200.          * brackets the optimum.
  201.          */
  202.         private final double initialBracketingRange;
  203.         /**
  204.          * Optimizer on behalf of which the line search must be performed.
  205.          */
  206.         private final MultivariateOptimizer mainOptimizer;

  207.         /**
  208.          * The {@code BrentOptimizer} default stopping criterion uses the
  209.          * tolerances to check the domain (point) values, not the function
  210.          * values.
  211.          * The {@code relativeTolerance} and {@code absoluteTolerance}
  212.          * arguments are thus passed to a {@link SimpleUnivariateValueChecker
  213.          * custom checker} that will use the function values.
  214.          *
  215.          * @param optimizer Optimizer on behalf of which the line search
  216.          * be performed.
  217.          * Its {@link MultivariateOptimizer#getObjectiveFunction() objective
  218.          * function} will be called by the {@link #search(double[],double[])
  219.          * search} method.
  220.          * @param relativeTolerance Search will stop when the function relative
  221.          * difference between successive iterations is below this value.
  222.          * @param absoluteTolerance Search will stop when the function absolute
  223.          * difference between successive iterations is below this value.
  224.          * @param initialBracketingRange Extent of the initial interval used to
  225.          * find an interval that brackets the optimum.
  226.          * If the optimized function varies a lot in the vicinity of the optimum,
  227.          * it may be necessary to provide a value lower than the distance between
  228.          * successive local minima.
  229.          */
  230.         /* package-private */ LineSearch(MultivariateOptimizer optimizer,
  231.                                          double relativeTolerance,
  232.                                          double absoluteTolerance,
  233.                                          double initialBracketingRange) {
  234.             mainOptimizer = optimizer;
  235.             lineOptimizer = new BrentOptimizer(REL_TOL_UNUSED,
  236.                                                ABS_TOL_UNUSED,
  237.                                                new SimpleUnivariateValueChecker(relativeTolerance,
  238.                                                                                 absoluteTolerance));
  239.             this.initialBracketingRange = initialBracketingRange;
  240.         }

  241.         /**
  242.          * Finds the number {@code alpha} that optimizes
  243.          * {@code f(startPoint + alpha * direction)}.
  244.          *
  245.          * @param startPoint Starting point.
  246.          * @param direction Search direction.
  247.          * @return the optimum.
  248.          * @throws org.apache.commons.math4.legacy.exception.TooManyEvaluationsException
  249.          * if the number of evaluations is exceeded.
  250.          */
  251.         /* package-private */ UnivariatePointValuePair search(final double[] startPoint,
  252.                                                               final double[] direction) {
  253.             final int n = startPoint.length;
  254.             final MultivariateFunction func = mainOptimizer.getObjectiveFunction();
  255.             final UnivariateFunction f = new UnivariateFunction() {
  256.                     /** {@inheritDoc} */
  257.                     @Override
  258.                     public double value(double alpha) {
  259.                         final double[] x = new double[n];
  260.                         for (int i = 0; i < n; i++) {
  261.                             x[i] = startPoint[i] + alpha * direction[i];
  262.                         }
  263.                         return func.value(x);
  264.                     }
  265.                 };

  266.             final GoalType goal = mainOptimizer.getGoalType();
  267.             bracket.search(f, goal, 0, initialBracketingRange);
  268.             // Passing "MAX_VALUE" as a dummy value because it is the enclosing
  269.             // class that counts the number of evaluations (and will eventually
  270.             // generate the exception).
  271.             return lineOptimizer.optimize(new MaxEval(Integer.MAX_VALUE),
  272.                                           new UnivariateObjectiveFunction(f),
  273.                                           goal,
  274.                                           new SearchInterval(bracket.getLo(),
  275.                                                              bracket.getHi(),
  276.                                                              bracket.getMid()));
  277.         }
  278.     }
  279. }