View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting;
18  
19  import java.util.Arrays;
20  import java.util.Collection;
21  
22  import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
23  import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
24  import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
25  import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresOptimizer;
26  import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
27  import org.apache.commons.math4.legacy.fitting.leastsquares.LevenbergMarquardtOptimizer;
28  
29  /**
30   * Base class that contains common code for fitting parametric univariate
31   * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
32   * the independent variable and the <code>p<sub>i</sub></code> are the
33   * <em>parameters</em>.
34   * <br>
35   * A fitter will find the optimal values of the parameters by
36   * <em>fitting</em> the curve so it remains very close to a set of
37   * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
38   * {@code 0 <= k < N}.
39   * <br>
40   * An algorithm usually performs the fit by finding the parameter
41   * values that minimizes the objective function
42   * <pre><code>
43   *  &sum;y<sub>k</sub> - f(x<sub>k</sub>)<sup>2</sup>,
44   * </code></pre>
45   * which is actually a least-squares problem.
46   * This class contains boilerplate code for calling the
47   * {@link #fit(Collection)} method for obtaining the parameters.
48   * The problem setup, such as the choice of optimization algorithm
49   * for fitting a specific function is delegated to subclasses.
50   *
51   * @since 3.3
52   */
53  public abstract class AbstractCurveFitter {
54      /**
55       * Fits a curve.
56       * This method computes the coefficients of the curve that best
57       * fit the sample of observed points.
58       *
59       * @param points Observations.
60       * @return the fitted parameters.
61       */
62      public double[] fit(Collection<WeightedObservedPoint> points) {
63          // Perform the fit.
64          return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
65      }
66  
67      /**
68       * Creates an optimizer set up to fit the appropriate curve.
69       * <p>
70       * The default implementation uses a {@link LevenbergMarquardtOptimizer
71       * Levenberg-Marquardt} optimizer.
72       * </p>
73       * @return the optimizer to use for fitting the curve to the
74       * given {@code points}.
75       */
76      protected LeastSquaresOptimizer getOptimizer() {
77          return new LevenbergMarquardtOptimizer();
78      }
79  
80      /**
81       * Creates a least squares problem corresponding to the appropriate curve.
82       *
83       * @param points Sample points.
84       * @return the least squares problem to use for fitting the curve to the
85       * given {@code points}.
86       */
87      protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);
88  
89      /**
90       * Vector function for computing function theoretical values.
91       */
92      protected static class TheoreticalValuesFunction {
93          /** Function to fit. */
94          private final ParametricUnivariateFunction f;
95          /** Observations. */
96          private final double[] points;
97  
98          /**
99           * @param f function to fit.
100          * @param observations Observations.
101          */
102         public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
103                                          final Collection<WeightedObservedPoint> observations) {
104             this.f = f;
105             this.points = observations.stream().mapToDouble(WeightedObservedPoint::getX).toArray();
106         }
107 
108         /**
109          * @return the model function values.
110          */
111         public MultivariateVectorFunction getModelFunction() {
112             return new MultivariateVectorFunction() {
113                 /** {@inheritDoc} */
114                 @Override
115                 public double[] value(double[] p) {
116                     return Arrays.stream(points).map(point -> f.value(point, p)).toArray();
117                 }
118             };
119         }
120 
121         /**
122          * @return the model function Jacobian.
123          */
124         public MultivariateMatrixFunction getModelFunctionJacobian() {
125             return new MultivariateMatrixFunction() {
126                 /** {@inheritDoc} */
127                 @Override
128                 public double[][] value(double[] p) {
129                     final int len = points.length;
130                     final double[][] jacobian = new double[len][];
131                     for (int i = 0; i < len; i++) {
132                         jacobian[i] = f.gradient(points[i], p);
133                     }
134                     return jacobian;
135                 }
136             };
137         }
138     }
139 }