AbstractCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.fitting;

  18. import java.util.Arrays;
  19. import java.util.Collection;

  20. import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
  21. import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
  22. import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
  23. import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresOptimizer;
  24. import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
  25. import org.apache.commons.math4.legacy.fitting.leastsquares.LevenbergMarquardtOptimizer;

  26. /**
  27.  * Base class that contains common code for fitting parametric univariate
  28.  * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
  29.  * the independent variable and the <code>p<sub>i</sub></code> are the
  30.  * <em>parameters</em>.
  31.  * <br>
  32.  * A fitter will find the optimal values of the parameters by
  33.  * <em>fitting</em> the curve so it remains very close to a set of
  34.  * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
  35.  * {@code 0 <= k < N}.
  36.  * <br>
  37.  * An algorithm usually performs the fit by finding the parameter
  38.  * values that minimizes the objective function
  39.  * <pre><code>
  40.  *  &sum;y<sub>k</sub> - f(x<sub>k</sub>)<sup>2</sup>,
  41.  * </code></pre>
  42.  * which is actually a least-squares problem.
  43.  * This class contains boilerplate code for calling the
  44.  * {@link #fit(Collection)} method for obtaining the parameters.
  45.  * The problem setup, such as the choice of optimization algorithm
  46.  * for fitting a specific function is delegated to subclasses.
  47.  *
  48.  * @since 3.3
  49.  */
  50. public abstract class AbstractCurveFitter {
  51.     /**
  52.      * Fits a curve.
  53.      * This method computes the coefficients of the curve that best
  54.      * fit the sample of observed points.
  55.      *
  56.      * @param points Observations.
  57.      * @return the fitted parameters.
  58.      */
  59.     public double[] fit(Collection<WeightedObservedPoint> points) {
  60.         // Perform the fit.
  61.         return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
  62.     }

  63.     /**
  64.      * Creates an optimizer set up to fit the appropriate curve.
  65.      * <p>
  66.      * The default implementation uses a {@link LevenbergMarquardtOptimizer
  67.      * Levenberg-Marquardt} optimizer.
  68.      * </p>
  69.      * @return the optimizer to use for fitting the curve to the
  70.      * given {@code points}.
  71.      */
  72.     protected LeastSquaresOptimizer getOptimizer() {
  73.         return new LevenbergMarquardtOptimizer();
  74.     }

  75.     /**
  76.      * Creates a least squares problem corresponding to the appropriate curve.
  77.      *
  78.      * @param points Sample points.
  79.      * @return the least squares problem to use for fitting the curve to the
  80.      * given {@code points}.
  81.      */
  82.     protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);

  83.     /**
  84.      * Vector function for computing function theoretical values.
  85.      */
  86.     protected static class TheoreticalValuesFunction {
  87.         /** Function to fit. */
  88.         private final ParametricUnivariateFunction f;
  89.         /** Observations. */
  90.         private final double[] points;

  91.         /**
  92.          * @param f function to fit.
  93.          * @param observations Observations.
  94.          */
  95.         public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
  96.                                          final Collection<WeightedObservedPoint> observations) {
  97.             this.f = f;
  98.             this.points = observations.stream().mapToDouble(WeightedObservedPoint::getX).toArray();
  99.         }

  100.         /**
  101.          * @return the model function values.
  102.          */
  103.         public MultivariateVectorFunction getModelFunction() {
  104.             return new MultivariateVectorFunction() {
  105.                 /** {@inheritDoc} */
  106.                 @Override
  107.                 public double[] value(double[] p) {
  108.                     return Arrays.stream(points).map(point -> f.value(point, p)).toArray();
  109.                 }
  110.             };
  111.         }

  112.         /**
  113.          * @return the model function Jacobian.
  114.          */
  115.         public MultivariateMatrixFunction getModelFunctionJacobian() {
  116.             return new MultivariateMatrixFunction() {
  117.                 /** {@inheritDoc} */
  118.                 @Override
  119.                 public double[][] value(double[] p) {
  120.                     final int len = points.length;
  121.                     final double[][] jacobian = new double[len][];
  122.                     for (int i = 0; i < len; i++) {
  123.                         jacobian[i] = f.gradient(points[i], p);
  124.                     }
  125.                     return jacobian;
  126.                 }
  127.             };
  128.         }
  129.     }
  130. }