View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.optimization.fitting;
19  
20  import java.util.ArrayList;
21  import java.util.List;
22  
23  import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
24  import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
25  import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
26  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
27  import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
28  import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
29  import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer;
30  import org.apache.commons.math3.optimization.PointVectorValuePair;
31  
32  /** Fitter for parametric univariate real functions y = f(x).
33   * <br/>
34   * When a univariate real function y = f(x) does depend on some
35   * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
36   * this class can be used to find these parameters. It does this
37   * by <em>fitting</em> the curve so it remains very close to a set of
38   * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
39   * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
40   * is done by finding the parameters values that minimizes the objective
41   * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
42   * really a least squares problem.
43   *
44   * @param <T> Function to use for the fit.
45   *
46   * @deprecated As of 3.1 (to be removed in 4.0).
47   * @since 2.0
48   */
49  @Deprecated
50  public class CurveFitter<T extends ParametricUnivariateFunction> {
51  
52      /** Optimizer to use for the fitting.
53       * @deprecated as of 3.1 replaced by {@link #optimizer}
54       */
55      @Deprecated
56      private final DifferentiableMultivariateVectorOptimizer oldOptimizer;
57  
58      /** Optimizer to use for the fitting. */
59      private final MultivariateDifferentiableVectorOptimizer optimizer;
60  
61      /** Observed points. */
62      private final List<WeightedObservedPoint> observations;
63  
64      /** Simple constructor.
65       * @param optimizer optimizer to use for the fitting
66       * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
67       */
68      @Deprecated
69      public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
70          this.oldOptimizer = optimizer;
71          this.optimizer    = null;
72          observations      = new ArrayList<WeightedObservedPoint>();
73      }
74  
75      /** Simple constructor.
76       * @param optimizer optimizer to use for the fitting
77       * @since 3.1
78       */
79      public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
80          this.oldOptimizer = null;
81          this.optimizer    = optimizer;
82          observations      = new ArrayList<WeightedObservedPoint>();
83      }
84  
85      /** Add an observed (x,y) point to the sample with unit weight.
86       * <p>Calling this method is equivalent to call
87       * {@code addObservedPoint(1.0, x, y)}.</p>
88       * @param x abscissa of the point
89       * @param y observed value of the point at x, after fitting we should
90       * have f(x) as close as possible to this value
91       * @see #addObservedPoint(double, double, double)
92       * @see #addObservedPoint(WeightedObservedPoint)
93       * @see #getObservations()
94       */
95      public void addObservedPoint(double x, double y) {
96          addObservedPoint(1.0, x, y);
97      }
98  
99      /** Add an observed weighted (x,y) point to the sample.
100      * @param weight weight of the observed point in the fit
101      * @param x abscissa of the point
102      * @param y observed value of the point at x, after fitting we should
103      * have f(x) as close as possible to this value
104      * @see #addObservedPoint(double, double)
105      * @see #addObservedPoint(WeightedObservedPoint)
106      * @see #getObservations()
107      */
108     public void addObservedPoint(double weight, double x, double y) {
109         observations.add(new WeightedObservedPoint(weight, x, y));
110     }
111 
112     /** Add an observed weighted (x,y) point to the sample.
113      * @param observed observed point to add
114      * @see #addObservedPoint(double, double)
115      * @see #addObservedPoint(double, double, double)
116      * @see #getObservations()
117      */
118     public void addObservedPoint(WeightedObservedPoint observed) {
119         observations.add(observed);
120     }
121 
122     /** Get the observed points.
123      * @return observed points
124      * @see #addObservedPoint(double, double)
125      * @see #addObservedPoint(double, double, double)
126      * @see #addObservedPoint(WeightedObservedPoint)
127      */
128     public WeightedObservedPoint[] getObservations() {
129         return observations.toArray(new WeightedObservedPoint[observations.size()]);
130     }
131 
132     /**
133      * Remove all observations.
134      */
135     public void clearObservations() {
136         observations.clear();
137     }
138 
139     /**
140      * Fit a curve.
141      * This method compute the coefficients of the curve that best
142      * fit the sample of observed points previously given through calls
143      * to the {@link #addObservedPoint(WeightedObservedPoint)
144      * addObservedPoint} method.
145      *
146      * @param f parametric function to fit.
147      * @param initialGuess first guess of the function parameters.
148      * @return the fitted parameters.
149      * @throws org.apache.commons.math3.exception.DimensionMismatchException
150      * if the start point dimension is wrong.
151      */
152     public double[] fit(T f, final double[] initialGuess) {
153         return fit(Integer.MAX_VALUE, f, initialGuess);
154     }
155 
156     /**
157      * Fit a curve.
158      * This method compute the coefficients of the curve that best
159      * fit the sample of observed points previously given through calls
160      * to the {@link #addObservedPoint(WeightedObservedPoint)
161      * addObservedPoint} method.
162      *
163      * @param f parametric function to fit.
164      * @param initialGuess first guess of the function parameters.
165      * @param maxEval Maximum number of function evaluations.
166      * @return the fitted parameters.
167      * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
168      * if the number of allowed evaluations is exceeded.
169      * @throws org.apache.commons.math3.exception.DimensionMismatchException
170      * if the start point dimension is wrong.
171      * @since 3.0
172      */
173     public double[] fit(int maxEval, T f,
174                         final double[] initialGuess) {
175         // prepare least squares problem
176         double[] target  = new double[observations.size()];
177         double[] weights = new double[observations.size()];
178         int i = 0;
179         for (WeightedObservedPoint point : observations) {
180             target[i]  = point.getY();
181             weights[i] = point.getWeight();
182             ++i;
183         }
184 
185         // perform the fit
186         final PointVectorValuePair optimum;
187         if (optimizer == null) {
188             // to be removed in 4.0
189             optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
190                                             target, weights, initialGuess);
191         } else {
192             optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
193                                          target, weights, initialGuess);
194         }
195 
196         // extract the coefficients
197         return optimum.getPointRef();
198     }
199 
200     /** Vectorial function computing function theoretical values. */
201     @Deprecated
202     private class OldTheoreticalValuesFunction
203         implements DifferentiableMultivariateVectorFunction {
204         /** Function to fit. */
205         private final ParametricUnivariateFunction f;
206 
207         /** Simple constructor.
208          * @param f function to fit.
209          */
210         public OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
211             this.f = f;
212         }
213 
214         /** {@inheritDoc} */
215         public MultivariateMatrixFunction jacobian() {
216             return new MultivariateMatrixFunction() {
217                 public double[][] value(double[] point) {
218                     final double[][] jacobian = new double[observations.size()][];
219 
220                     int i = 0;
221                     for (WeightedObservedPoint observed : observations) {
222                         jacobian[i++] = f.gradient(observed.getX(), point);
223                     }
224 
225                     return jacobian;
226                 }
227             };
228         }
229 
230         /** {@inheritDoc} */
231         public double[] value(double[] point) {
232             // compute the residuals
233             final double[] values = new double[observations.size()];
234             int i = 0;
235             for (WeightedObservedPoint observed : observations) {
236                 values[i++] = f.value(observed.getX(), point);
237             }
238 
239             return values;
240         }
241     }
242 
243     /** Vectorial function computing function theoretical values. */
244     private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction {
245 
246         /** Function to fit. */
247         private final ParametricUnivariateFunction f;
248 
249         /** Simple constructor.
250          * @param f function to fit.
251          */
252         public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
253             this.f = f;
254         }
255 
256         /** {@inheritDoc} */
257         public double[] value(double[] point) {
258             // compute the residuals
259             final double[] values = new double[observations.size()];
260             int i = 0;
261             for (WeightedObservedPoint observed : observations) {
262                 values[i++] = f.value(observed.getX(), point);
263             }
264 
265             return values;
266         }
267 
268         /** {@inheritDoc} */
269         public DerivativeStructure[] value(DerivativeStructure[] point) {
270 
271             // extract parameters
272             final double[] parameters = new double[point.length];
273             for (int k = 0; k < point.length; ++k) {
274                 parameters[k] = point[k].getValue();
275             }
276 
277             // compute the residuals
278             final DerivativeStructure[] values = new DerivativeStructure[observations.size()];
279             int i = 0;
280             for (WeightedObservedPoint observed : observations) {
281 
282                 // build the DerivativeStructure by adding first the value as a constant
283                 // and then adding derivatives
284                 DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters));
285                 for (int k = 0; k < point.length; ++k) {
286                     vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0));
287                 }
288 
289                 values[i++] = vi;
290 
291             }
292 
293             return values;
294         }
295 
296     }
297 
298 }