View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math3.fitting;
18  
19  import java.util.ArrayList;
20  import java.util.List;
21  import org.apache.commons.math3.analysis.MultivariateVectorFunction;
22  import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
23  import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
24  import org.apache.commons.math3.optim.MaxEval;
25  import org.apache.commons.math3.optim.InitialGuess;
26  import org.apache.commons.math3.optim.PointVectorValuePair;
27  import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
28  import org.apache.commons.math3.optim.nonlinear.vector.ModelFunction;
29  import org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian;
30  import org.apache.commons.math3.optim.nonlinear.vector.Target;
31  import org.apache.commons.math3.optim.nonlinear.vector.Weight;
32  
33  /**
34   * Fitter for parametric univariate real functions y = f(x).
35   * <br/>
36   * When a univariate real function y = f(x) does depend on some
37   * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
38   * this class can be used to find these parameters. It does this
39   * by <em>fitting</em> the curve so it remains very close to a set of
40   * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
41   * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
42   * is done by finding the parameters values that minimizes the objective
43   * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
44   * really a least squares problem.
45   *
46   * @param <T> Function to use for the fit.
47   *
48   * @version $Id: CurveFitter.java 1416643 2012-12-03 19:37:14Z tn $
49   * @since 2.0
50   */
51  public class CurveFitter<T extends ParametricUnivariateFunction> {
52      /** Optimizer to use for the fitting. */
53      private final MultivariateVectorOptimizer optimizer;
54      /** Observed points. */
55      private final List<WeightedObservedPoint> observations;
56  
57      /**
58       * Simple constructor.
59       *
60       * @param optimizer Optimizer to use for the fitting.
61       * @since 3.1
62       */
63      public CurveFitter(final MultivariateVectorOptimizer optimizer) {
64          this.optimizer = optimizer;
65          observations = new ArrayList<WeightedObservedPoint>();
66      }
67  
68      /** Add an observed (x,y) point to the sample with unit weight.
69       * <p>Calling this method is equivalent to call
70       * {@code addObservedPoint(1.0, x, y)}.</p>
71       * @param x abscissa of the point
72       * @param y observed value of the point at x, after fitting we should
73       * have f(x) as close as possible to this value
74       * @see #addObservedPoint(double, double, double)
75       * @see #addObservedPoint(WeightedObservedPoint)
76       * @see #getObservations()
77       */
78      public void addObservedPoint(double x, double y) {
79          addObservedPoint(1.0, x, y);
80      }
81  
82      /** Add an observed weighted (x,y) point to the sample.
83       * @param weight weight of the observed point in the fit
84       * @param x abscissa of the point
85       * @param y observed value of the point at x, after fitting we should
86       * have f(x) as close as possible to this value
87       * @see #addObservedPoint(double, double)
88       * @see #addObservedPoint(WeightedObservedPoint)
89       * @see #getObservations()
90       */
91      public void addObservedPoint(double weight, double x, double y) {
92          observations.add(new WeightedObservedPoint(weight, x, y));
93      }
94  
95      /** Add an observed weighted (x,y) point to the sample.
96       * @param observed observed point to add
97       * @see #addObservedPoint(double, double)
98       * @see #addObservedPoint(double, double, double)
99       * @see #getObservations()
100      */
101     public void addObservedPoint(WeightedObservedPoint observed) {
102         observations.add(observed);
103     }
104 
105     /** Get the observed points.
106      * @return observed points
107      * @see #addObservedPoint(double, double)
108      * @see #addObservedPoint(double, double, double)
109      * @see #addObservedPoint(WeightedObservedPoint)
110      */
111     public WeightedObservedPoint[] getObservations() {
112         return observations.toArray(new WeightedObservedPoint[observations.size()]);
113     }
114 
115     /**
116      * Remove all observations.
117      */
118     public void clearObservations() {
119         observations.clear();
120     }
121 
122     /**
123      * Fit a curve.
124      * This method compute the coefficients of the curve that best
125      * fit the sample of observed points previously given through calls
126      * to the {@link #addObservedPoint(WeightedObservedPoint)
127      * addObservedPoint} method.
128      *
129      * @param f parametric function to fit.
130      * @param initialGuess first guess of the function parameters.
131      * @return the fitted parameters.
132      * @throws org.apache.commons.math3.exception.DimensionMismatchException
133      * if the start point dimension is wrong.
134      */
135     public double[] fit(T f, final double[] initialGuess) {
136         return fit(Integer.MAX_VALUE, f, initialGuess);
137     }
138 
139     /**
140      * Fit a curve.
141      * This method compute the coefficients of the curve that best
142      * fit the sample of observed points previously given through calls
143      * to the {@link #addObservedPoint(WeightedObservedPoint)
144      * addObservedPoint} method.
145      *
146      * @param f parametric function to fit.
147      * @param initialGuess first guess of the function parameters.
148      * @param maxEval Maximum number of function evaluations.
149      * @return the fitted parameters.
150      * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
151      * if the number of allowed evaluations is exceeded.
152      * @throws org.apache.commons.math3.exception.DimensionMismatchException
153      * if the start point dimension is wrong.
154      * @since 3.0
155      */
156     public double[] fit(int maxEval, T f,
157                         final double[] initialGuess) {
158         // Prepare least squares problem.
159         double[] target  = new double[observations.size()];
160         double[] weights = new double[observations.size()];
161         int i = 0;
162         for (WeightedObservedPoint point : observations) {
163             target[i]  = point.getY();
164             weights[i] = point.getWeight();
165             ++i;
166         }
167 
168         // Input to the optimizer: the model and its Jacobian.
169         final TheoreticalValuesFunction model = new TheoreticalValuesFunction(f);
170 
171         // Perform the fit.
172         final PointVectorValuePair optimum
173             = optimizer.optimize(new MaxEval(maxEval),
174                                  model.getModelFunction(),
175                                  model.getModelFunctionJacobian(),
176                                  new Target(target),
177                                  new Weight(weights),
178                                  new InitialGuess(initialGuess));
179         // Extract the coefficients.
180         return optimum.getPointRef();
181     }
182 
183     /** Vectorial function computing function theoretical values. */
184     private class TheoreticalValuesFunction {
185         /** Function to fit. */
186         private final ParametricUnivariateFunction f;
187 
188         /**
189          * @param f function to fit.
190          */
191         public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
192             this.f = f;
193         }
194 
195         /**
196          * @return the model function values.
197          */
198         public ModelFunction getModelFunction() {
199             return new ModelFunction(new MultivariateVectorFunction() {
200                     /** {@inheritDoc} */
201                     public double[] value(double[] point) {
202                         // compute the residuals
203                         final double[] values = new double[observations.size()];
204                         int i = 0;
205                         for (WeightedObservedPoint observed : observations) {
206                             values[i++] = f.value(observed.getX(), point);
207                         }
208 
209                         return values;
210                     }
211                 });
212         }
213 
214         /**
215          * @return the model function Jacobian.
216          */
217         public ModelFunctionJacobian getModelFunctionJacobian() {
218             return new ModelFunctionJacobian(new MultivariateMatrixFunction() {
219                     public double[][] value(double[] point) {
220                         final double[][] jacobian = new double[observations.size()][];
221                         int i = 0;
222                         for (WeightedObservedPoint observed : observations) {
223                             jacobian[i++] = f.gradient(observed.getX(), point);
224                         }
225                         return jacobian;
226                     }
227                 });
228         }
229     }
230 }