LeastSquaresFactory.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.fitting.leastsquares;

  18. import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
  19. import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
  20. import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
  21. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  22. import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem.Evaluation;
  23. import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
  24. import org.apache.commons.math4.legacy.linear.ArrayRealVector;
  25. import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
  26. import org.apache.commons.math4.legacy.linear.EigenDecomposition;
  27. import org.apache.commons.math4.legacy.linear.RealMatrix;
  28. import org.apache.commons.math4.legacy.linear.RealVector;
  29. import org.apache.commons.math4.legacy.optim.AbstractOptimizationProblem;
  30. import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
  31. import org.apache.commons.math4.legacy.optim.PointVectorValuePair;
  32. import org.apache.commons.math4.core.jdkmath.JdkMath;
  33. import org.apache.commons.math4.legacy.core.IntegerSequence;
  34. import org.apache.commons.math4.legacy.core.Pair;

  35. /**
  36.  * A Factory for creating {@link LeastSquaresProblem}s.
  37.  *
  38.  * @since 3.3
  39.  */
  40. public final class LeastSquaresFactory {

  41.     /** Prevent instantiation. */
  42.     private LeastSquaresFactory() {}

  43.     /**
  44.      * Create a {@link org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem}
  45.      * from the given elements.
  46.      *
  47.      * @param model          the model function. Produces the computed values.
  48.      * @param observed       the observed (target) values
  49.      * @param start          the initial guess.
  50.      * @param weight         the weight matrix
  51.      * @param checker        convergence checker
  52.      * @param maxEvaluations the maximum number of times to evaluate the model
  53.      * @param maxIterations  the maximum number to times to iterate in the algorithm
  54.      * @param lazyEvaluation Whether the call to {@link Evaluation#evaluate(RealVector)}
  55.      * will defer the evaluation until access to the value is requested.
  56.      * @param paramValidator Model parameters validator.
  57.      * @return the specified General Least Squares problem.
  58.      *
  59.      * @since 3.4
  60.      */
  61.     public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
  62.                                              final RealVector observed,
  63.                                              final RealVector start,
  64.                                              final RealMatrix weight,
  65.                                              final ConvergenceChecker<Evaluation> checker,
  66.                                              final int maxEvaluations,
  67.                                              final int maxIterations,
  68.                                              final boolean lazyEvaluation,
  69.                                              final ParameterValidator paramValidator) {
  70.         final LeastSquaresProblem p = new LocalLeastSquaresProblem(model,
  71.                                                                    observed,
  72.                                                                    start,
  73.                                                                    checker,
  74.                                                                    maxEvaluations,
  75.                                                                    maxIterations,
  76.                                                                    lazyEvaluation,
  77.                                                                    paramValidator);
  78.         if (weight != null) {
  79.             return weightMatrix(p, weight);
  80.         } else {
  81.             return p;
  82.         }
  83.     }

  84.     /**
  85.      * Create a {@link org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem}
  86.      * from the given elements. There will be no weights applied (unit weights).
  87.      *
  88.      * @param model          the model function. Produces the computed values.
  89.      * @param observed       the observed (target) values
  90.      * @param start          the initial guess.
  91.      * @param checker        convergence checker
  92.      * @param maxEvaluations the maximum number of times to evaluate the model
  93.      * @param maxIterations  the maximum number to times to iterate in the algorithm
  94.      * @return the specified General Least Squares problem.
  95.      */
  96.     public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
  97.                                              final RealVector observed,
  98.                                              final RealVector start,
  99.                                              final ConvergenceChecker<Evaluation> checker,
  100.                                              final int maxEvaluations,
  101.                                              final int maxIterations) {
  102.         return create(model,
  103.                       observed,
  104.                       start,
  105.                       null,
  106.                       checker,
  107.                       maxEvaluations,
  108.                       maxIterations,
  109.                       false,
  110.                       null);
  111.     }

  112.     /**
  113.      * Create a {@link org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem}
  114.      * from the given elements.
  115.      *
  116.      * @param model          the model function. Produces the computed values.
  117.      * @param observed       the observed (target) values
  118.      * @param start          the initial guess.
  119.      * @param weight         the weight matrix
  120.      * @param checker        convergence checker
  121.      * @param maxEvaluations the maximum number of times to evaluate the model
  122.      * @param maxIterations  the maximum number to times to iterate in the algorithm
  123.      * @return the specified General Least Squares problem.
  124.      */
  125.     public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
  126.                                              final RealVector observed,
  127.                                              final RealVector start,
  128.                                              final RealMatrix weight,
  129.                                              final ConvergenceChecker<Evaluation> checker,
  130.                                              final int maxEvaluations,
  131.                                              final int maxIterations) {
  132.         return weightMatrix(create(model,
  133.                                    observed,
  134.                                    start,
  135.                                    checker,
  136.                                    maxEvaluations,
  137.                                    maxIterations),
  138.                             weight);
  139.     }

  140.     /**
  141.      * Create a {@link org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem}
  142.      * from the given elements.
  143.      * <p>
  144.      * This factory method is provided for continuity with previous interfaces. Newer
  145.      * applications should use {@link #create(MultivariateJacobianFunction, RealVector,
  146.      * RealVector, ConvergenceChecker, int, int)}, or {@link #create(MultivariateJacobianFunction,
  147.      * RealVector, RealVector, RealMatrix, ConvergenceChecker, int, int)}.
  148.      *
  149.      * @param model          the model function. Produces the computed values.
  150.      * @param jacobian       the jacobian of the model with respect to the parameters
  151.      * @param observed       the observed (target) values
  152.      * @param start          the initial guess.
  153.      * @param weight         the weight matrix
  154.      * @param checker        convergence checker
  155.      * @param maxEvaluations the maximum number of times to evaluate the model
  156.      * @param maxIterations  the maximum number to times to iterate in the algorithm
  157.      * @return the specified General Least Squares problem.
  158.      */
  159.     public static LeastSquaresProblem create(final MultivariateVectorFunction model,
  160.                                              final MultivariateMatrixFunction jacobian,
  161.                                              final double[] observed,
  162.                                              final double[] start,
  163.                                              final RealMatrix weight,
  164.                                              final ConvergenceChecker<Evaluation> checker,
  165.                                              final int maxEvaluations,
  166.                                              final int maxIterations) {
  167.         return create(model(model, jacobian),
  168.                       new ArrayRealVector(observed, false),
  169.                       new ArrayRealVector(start, false),
  170.                       weight,
  171.                       checker,
  172.                       maxEvaluations,
  173.                       maxIterations);
  174.     }

  175.     /**
  176.      * Apply a dense weight matrix to the {@link LeastSquaresProblem}.
  177.      *
  178.      * @param problem the unweighted problem
  179.      * @param weights the matrix of weights
  180.      * @return a new {@link LeastSquaresProblem} with the weights applied. The original
  181.      *         {@code problem} is not modified.
  182.      */
  183.     public static LeastSquaresProblem weightMatrix(final LeastSquaresProblem problem,
  184.                                                    final RealMatrix weights) {
  185.         final RealMatrix weightSquareRoot = squareRoot(weights);
  186.         return new LeastSquaresAdapter(problem) {
  187.             /** {@inheritDoc} */
  188.             @Override
  189.             public Evaluation evaluate(final RealVector point) {
  190.                 return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
  191.             }
  192.         };
  193.     }

  194.     /**
  195.      * Apply a diagonal weight matrix to the {@link LeastSquaresProblem}.
  196.      *
  197.      * @param problem the unweighted problem
  198.      * @param weights the diagonal of the weight matrix
  199.      * @return a new {@link LeastSquaresProblem} with the weights applied. The original
  200.      *         {@code problem} is not modified.
  201.      */
  202.     public static LeastSquaresProblem weightDiagonal(final LeastSquaresProblem problem,
  203.                                                      final RealVector weights) {
  204.         // TODO more efficient implementation
  205.         return weightMatrix(problem, new DiagonalMatrix(weights.toArray()));
  206.     }

  207.     /**
  208.      * Count the evaluations of a particular problem. The {@code counter} will be
  209.      * incremented every time {@link LeastSquaresProblem#evaluate(RealVector)} is called on
  210.      * the <em>returned</em> problem.
  211.      *
  212.      * @param problem the problem to track.
  213.      * @param counter the counter to increment.
  214.      * @return a least squares problem that tracks evaluations
  215.      */
  216.     public static LeastSquaresProblem countEvaluations(final LeastSquaresProblem problem,
  217.                                                        final IntegerSequence.Incrementor counter) {
  218.         return new LeastSquaresAdapter(problem) {

  219.             /** {@inheritDoc} */
  220.             @Override
  221.             public Evaluation evaluate(final RealVector point) {
  222.                 counter.increment();
  223.                 return super.evaluate(point);
  224.             }

  225.             // Delegate the rest.
  226.         };
  227.     }

  228.     /**
  229.      * View a convergence checker specified for a {@link PointVectorValuePair} as one
  230.      * specified for an {@link Evaluation}.
  231.      *
  232.      * @param checker the convergence checker to adapt.
  233.      * @return a convergence checker that delegates to {@code checker}.
  234.      */
  235.     public static ConvergenceChecker<Evaluation> evaluationChecker(final ConvergenceChecker<PointVectorValuePair> checker) {
  236.         return new ConvergenceChecker<Evaluation>() {
  237.             /** {@inheritDoc} */
  238.             @Override
  239.             public boolean converged(final int iteration,
  240.                                      final Evaluation previous,
  241.                                      final Evaluation current) {
  242.                 return checker.converged(
  243.                         iteration,
  244.                         new PointVectorValuePair(
  245.                                 previous.getPoint().toArray(),
  246.                                 previous.getResiduals().toArray(),
  247.                                 false),
  248.                         new PointVectorValuePair(
  249.                                 current.getPoint().toArray(),
  250.                                 current.getResiduals().toArray(),
  251.                                 false)
  252.                 );
  253.             }
  254.         };
  255.     }

  256.     /**
  257.      * Computes the square-root of the weight matrix.
  258.      *
  259.      * @param m Symmetric, positive-definite (weight) matrix.
  260.      * @return the square-root of the weight matrix.
  261.      */
  262.     private static RealMatrix squareRoot(final RealMatrix m) {
  263.         if (m instanceof DiagonalMatrix) {
  264.             final int dim = m.getRowDimension();
  265.             final RealMatrix sqrtM = new DiagonalMatrix(dim);
  266.             for (int i = 0; i < dim; i++) {
  267.                 sqrtM.setEntry(i, i, JdkMath.sqrt(m.getEntry(i, i)));
  268.             }
  269.             return sqrtM;
  270.         } else {
  271.             final EigenDecomposition dec = new EigenDecomposition(m);
  272.             return dec.getSquareRoot();
  273.         }
  274.     }

  275.     /**
  276.      * Combine a {@link MultivariateVectorFunction} with a {@link
  277.      * MultivariateMatrixFunction} to produce a {@link MultivariateJacobianFunction}.
  278.      *
  279.      * @param value    the vector value function
  280.      * @param jacobian the Jacobian function
  281.      * @return a function that computes both at the same time
  282.      */
  283.     public static MultivariateJacobianFunction model(final MultivariateVectorFunction value,
  284.                                                      final MultivariateMatrixFunction jacobian) {
  285.         return new LocalValueAndJacobianFunction(value, jacobian);
  286.     }

  287.     /**
  288.      * Combine a {@link MultivariateVectorFunction} with a {@link
  289.      * MultivariateMatrixFunction} to produce a {@link MultivariateJacobianFunction}.
  290.      */
  291.     private static final class LocalValueAndJacobianFunction
  292.         implements ValueAndJacobianFunction {
  293.         /** Model. */
  294.         private final MultivariateVectorFunction value;
  295.         /** Model's Jacobian. */
  296.         private final MultivariateMatrixFunction jacobian;

  297.         /**
  298.          * @param value Model function.
  299.          * @param jacobian Model's Jacobian function.
  300.          */
  301.         LocalValueAndJacobianFunction(final MultivariateVectorFunction value,
  302.                                       final MultivariateMatrixFunction jacobian) {
  303.             this.value = value;
  304.             this.jacobian = jacobian;
  305.         }

  306.         /** {@inheritDoc} */
  307.         @Override
  308.         public Pair<RealVector, RealMatrix> value(final RealVector point) {
  309.             //TODO get array from RealVector without copying?
  310.             final double[] p = point.toArray();

  311.             // Evaluate.
  312.             return new Pair<>(computeValue(p), computeJacobian(p));
  313.         }

  314.         /** {@inheritDoc} */
  315.         @Override
  316.         public RealVector computeValue(final double[] params) {
  317.             return new ArrayRealVector(value.value(params), false);
  318.         }

  319.         /** {@inheritDoc} */
  320.         @Override
  321.         public RealMatrix computeJacobian(final double[] params) {
  322.             return new Array2DRowRealMatrix(jacobian.value(params), false);
  323.         }
  324.     }


  325.     /**
  326.      * A private, "field" immutable (not "real" immutable) implementation of {@link
  327.      * LeastSquaresProblem}.
  328.      * @since 3.3
  329.      */
  330.     private static final class LocalLeastSquaresProblem
  331.             extends AbstractOptimizationProblem<Evaluation>
  332.             implements LeastSquaresProblem {

  333.         /** Target values for the model function at optimum. */
  334.         private final RealVector target;
  335.         /** Model function. */
  336.         private final MultivariateJacobianFunction model;
  337.         /** Initial guess. */
  338.         private final RealVector start;
  339.         /** Whether to use lazy evaluation. */
  340.         private final boolean lazyEvaluation;
  341.         /** Model parameters validator. */
  342.         private final ParameterValidator paramValidator;

  343.         /**
  344.          * Create a {@link LeastSquaresProblem} from the given data.
  345.          *
  346.          * @param model          the model function
  347.          * @param target         the observed data
  348.          * @param start          the initial guess
  349.          * @param checker        the convergence checker
  350.          * @param maxEvaluations the allowed evaluations
  351.          * @param maxIterations  the allowed iterations
  352.          * @param lazyEvaluation Whether the call to {@link Evaluation#evaluate(RealVector)}
  353.          * will defer the evaluation until access to the value is requested.
  354.          * @param paramValidator Model parameters validator.
  355.          */
  356.         LocalLeastSquaresProblem(final MultivariateJacobianFunction model,
  357.                                  final RealVector target,
  358.                                  final RealVector start,
  359.                                  final ConvergenceChecker<Evaluation> checker,
  360.                                  final int maxEvaluations,
  361.                                  final int maxIterations,
  362.                                  final boolean lazyEvaluation,
  363.                                  final ParameterValidator paramValidator) {
  364.             super(maxEvaluations, maxIterations, checker);
  365.             this.target = target;
  366.             this.model = model;
  367.             this.start = start;
  368.             this.lazyEvaluation = lazyEvaluation;
  369.             this.paramValidator = paramValidator;

  370.             if (lazyEvaluation &&
  371.                 !(model instanceof ValueAndJacobianFunction)) {
  372.                 // Lazy evaluation requires that value and Jacobian
  373.                 // can be computed separately.
  374.                 throw new MathIllegalStateException(LocalizedFormats.INVALID_IMPLEMENTATION,
  375.                                                     model.getClass().getName());
  376.             }
  377.         }

  378.         /** {@inheritDoc} */
  379.         @Override
  380.         public int getObservationSize() {
  381.             return target.getDimension();
  382.         }

  383.         /** {@inheritDoc} */
  384.         @Override
  385.         public int getParameterSize() {
  386.             return start.getDimension();
  387.         }

  388.         /** {@inheritDoc} */
  389.         @Override
  390.         public RealVector getStart() {
  391.             return start == null ? null : start.copy();
  392.         }

  393.         /** {@inheritDoc} */
  394.         @Override
  395.         public Evaluation evaluate(final RealVector point) {
  396.             // Copy so optimizer can change point without changing our instance.
  397.             final RealVector p = paramValidator == null ?
  398.                 point.copy() :
  399.                 paramValidator.validate(point.copy());

  400.             if (lazyEvaluation) {
  401.                 return new LazyUnweightedEvaluation((ValueAndJacobianFunction) model,
  402.                                                     target,
  403.                                                     p);
  404.             } else {
  405.                 // Evaluate value and jacobian in one function call.
  406.                 final Pair<RealVector, RealMatrix> value = model.value(p);
  407.                 return new UnweightedEvaluation(value.getFirst(),
  408.                                                 value.getSecond(),
  409.                                                 target,
  410.                                                 p);
  411.             }
  412.         }

  413.         /**
  414.          * Container with the model evaluation at a particular point.
  415.          */
  416.         private static final class UnweightedEvaluation extends AbstractEvaluation {
  417.             /** Point of evaluation. */
  418.             private final RealVector point;
  419.             /** Derivative at point. */
  420.             private final RealMatrix jacobian;
  421.             /** Computed residuals. */
  422.             private final RealVector residuals;

  423.             /**
  424.              * Create an {@link Evaluation} with no weights.
  425.              *
  426.              * @param values   the computed function values
  427.              * @param jacobian the computed function Jacobian
  428.              * @param target   the observed values
  429.              * @param point    the abscissa
  430.              */
  431.             private UnweightedEvaluation(final RealVector values,
  432.                                          final RealMatrix jacobian,
  433.                                          final RealVector target,
  434.                                          final RealVector point) {
  435.                 super(target.getDimension());
  436.                 this.jacobian = jacobian;
  437.                 this.point = point;
  438.                 this.residuals = target.subtract(values);
  439.             }

  440.             /** {@inheritDoc} */
  441.             @Override
  442.             public RealMatrix getJacobian() {
  443.                 return jacobian;
  444.             }

  445.             /** {@inheritDoc} */
  446.             @Override
  447.             public RealVector getPoint() {
  448.                 return point;
  449.             }

  450.             /** {@inheritDoc} */
  451.             @Override
  452.             public RealVector getResiduals() {
  453.                 return residuals;
  454.             }
  455.         }

  456.         /**
  457.          * Container with the model <em>lazy</em> evaluation at a particular point.
  458.          */
  459.         private static final class LazyUnweightedEvaluation extends AbstractEvaluation {
  460.             /** Point of evaluation. */
  461.             private final RealVector point;
  462.             /** Model and Jacobian functions. */
  463.             private final ValueAndJacobianFunction model;
  464.             /** Target values for the model function at optimum. */
  465.             private final RealVector target;

  466.             /**
  467.              * Create an {@link Evaluation} with no weights.
  468.              *
  469.              * @param model  the model function
  470.              * @param target the observed values
  471.              * @param point  the abscissa
  472.              */
  473.             private LazyUnweightedEvaluation(final ValueAndJacobianFunction model,
  474.                                              final RealVector target,
  475.                                              final RealVector point) {
  476.                 super(target.getDimension());
  477.                 // Safe to cast as long as we control usage of this class.
  478.                 this.model = model;
  479.                 this.point = point;
  480.                 this.target = target;
  481.             }

  482.             /** {@inheritDoc} */
  483.             @Override
  484.             public RealMatrix getJacobian() {
  485.                 return model.computeJacobian(point.toArray());
  486.             }

  487.             /** {@inheritDoc} */
  488.             @Override
  489.             public RealVector getPoint() {
  490.                 return point;
  491.             }

  492.             /** {@inheritDoc} */
  493.             @Override
  494.             public RealVector getResiduals() {
  495.                 return target.subtract(model.computeValue(point.toArray()));
  496.             }
  497.         }
  498.     }
  499. }