001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.optimization.fitting;
019
020import java.util.ArrayList;
021import java.util.List;
022
023import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
024import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
025import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
026import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
027import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
028import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
029import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer;
030import org.apache.commons.math3.optimization.PointVectorValuePair;
031
032/** Fitter for parametric univariate real functions y = f(x).
033 * <br/>
034 * When a univariate real function y = f(x) does depend on some
035 * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
036 * this class can be used to find these parameters. It does this
037 * by <em>fitting</em> the curve so it remains very close to a set of
038 * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
039 * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
040 * is done by finding the parameters values that minimizes the objective
041 * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
042 * really a least squares problem.
043 *
044 * @param <T> Function to use for the fit.
045 *
046 * @version $Id: CurveFitter.java 1499808 2013-07-04 17:00:42Z sebb $
047 * @deprecated As of 3.1 (to be removed in 4.0).
048 * @since 2.0
049 */
050@Deprecated
051public class CurveFitter<T extends ParametricUnivariateFunction> {
052
053    /** Optimizer to use for the fitting.
054     * @deprecated as of 3.1 replaced by {@link #optimizer}
055     */
056    @Deprecated
057    private final DifferentiableMultivariateVectorOptimizer oldOptimizer;
058
059    /** Optimizer to use for the fitting. */
060    private final MultivariateDifferentiableVectorOptimizer optimizer;
061
062    /** Observed points. */
063    private final List<WeightedObservedPoint> observations;
064
065    /** Simple constructor.
066     * @param optimizer optimizer to use for the fitting
067     * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
068     */
069    @Deprecated
070    public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
071        this.oldOptimizer = optimizer;
072        this.optimizer    = null;
073        observations      = new ArrayList<WeightedObservedPoint>();
074    }
075
076    /** Simple constructor.
077     * @param optimizer optimizer to use for the fitting
078     * @since 3.1
079     */
080    public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
081        this.oldOptimizer = null;
082        this.optimizer    = optimizer;
083        observations      = new ArrayList<WeightedObservedPoint>();
084    }
085
086    /** Add an observed (x,y) point to the sample with unit weight.
087     * <p>Calling this method is equivalent to call
088     * {@code addObservedPoint(1.0, x, y)}.</p>
089     * @param x abscissa of the point
090     * @param y observed value of the point at x, after fitting we should
091     * have f(x) as close as possible to this value
092     * @see #addObservedPoint(double, double, double)
093     * @see #addObservedPoint(WeightedObservedPoint)
094     * @see #getObservations()
095     */
096    public void addObservedPoint(double x, double y) {
097        addObservedPoint(1.0, x, y);
098    }
099
100    /** Add an observed weighted (x,y) point to the sample.
101     * @param weight weight of the observed point in the fit
102     * @param x abscissa of the point
103     * @param y observed value of the point at x, after fitting we should
104     * have f(x) as close as possible to this value
105     * @see #addObservedPoint(double, double)
106     * @see #addObservedPoint(WeightedObservedPoint)
107     * @see #getObservations()
108     */
109    public void addObservedPoint(double weight, double x, double y) {
110        observations.add(new WeightedObservedPoint(weight, x, y));
111    }
112
113    /** Add an observed weighted (x,y) point to the sample.
114     * @param observed observed point to add
115     * @see #addObservedPoint(double, double)
116     * @see #addObservedPoint(double, double, double)
117     * @see #getObservations()
118     */
119    public void addObservedPoint(WeightedObservedPoint observed) {
120        observations.add(observed);
121    }
122
123    /** Get the observed points.
124     * @return observed points
125     * @see #addObservedPoint(double, double)
126     * @see #addObservedPoint(double, double, double)
127     * @see #addObservedPoint(WeightedObservedPoint)
128     */
129    public WeightedObservedPoint[] getObservations() {
130        return observations.toArray(new WeightedObservedPoint[observations.size()]);
131    }
132
133    /**
134     * Remove all observations.
135     */
136    public void clearObservations() {
137        observations.clear();
138    }
139
140    /**
141     * Fit a curve.
142     * This method compute the coefficients of the curve that best
143     * fit the sample of observed points previously given through calls
144     * to the {@link #addObservedPoint(WeightedObservedPoint)
145     * addObservedPoint} method.
146     *
147     * @param f parametric function to fit.
148     * @param initialGuess first guess of the function parameters.
149     * @return the fitted parameters.
150     * @throws org.apache.commons.math3.exception.DimensionMismatchException
151     * if the start point dimension is wrong.
152     */
153    public double[] fit(T f, final double[] initialGuess) {
154        return fit(Integer.MAX_VALUE, f, initialGuess);
155    }
156
157    /**
158     * Fit a curve.
159     * This method compute the coefficients of the curve that best
160     * fit the sample of observed points previously given through calls
161     * to the {@link #addObservedPoint(WeightedObservedPoint)
162     * addObservedPoint} method.
163     *
164     * @param f parametric function to fit.
165     * @param initialGuess first guess of the function parameters.
166     * @param maxEval Maximum number of function evaluations.
167     * @return the fitted parameters.
168     * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
169     * if the number of allowed evaluations is exceeded.
170     * @throws org.apache.commons.math3.exception.DimensionMismatchException
171     * if the start point dimension is wrong.
172     * @since 3.0
173     */
174    public double[] fit(int maxEval, T f,
175                        final double[] initialGuess) {
176        // prepare least squares problem
177        double[] target  = new double[observations.size()];
178        double[] weights = new double[observations.size()];
179        int i = 0;
180        for (WeightedObservedPoint point : observations) {
181            target[i]  = point.getY();
182            weights[i] = point.getWeight();
183            ++i;
184        }
185
186        // perform the fit
187        final PointVectorValuePair optimum;
188        if (optimizer == null) {
189            // to be removed in 4.0
190            optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
191                                            target, weights, initialGuess);
192        } else {
193            optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
194                                         target, weights, initialGuess);
195        }
196
197        // extract the coefficients
198        return optimum.getPointRef();
199    }
200
201    /** Vectorial function computing function theoretical values. */
202    @Deprecated
203    private class OldTheoreticalValuesFunction
204        implements DifferentiableMultivariateVectorFunction {
205        /** Function to fit. */
206        private final ParametricUnivariateFunction f;
207
208        /** Simple constructor.
209         * @param f function to fit.
210         */
211        public OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
212            this.f = f;
213        }
214
215        /** {@inheritDoc} */
216        public MultivariateMatrixFunction jacobian() {
217            return new MultivariateMatrixFunction() {
218                public double[][] value(double[] point) {
219                    final double[][] jacobian = new double[observations.size()][];
220
221                    int i = 0;
222                    for (WeightedObservedPoint observed : observations) {
223                        jacobian[i++] = f.gradient(observed.getX(), point);
224                    }
225
226                    return jacobian;
227                }
228            };
229        }
230
231        /** {@inheritDoc} */
232        public double[] value(double[] point) {
233            // compute the residuals
234            final double[] values = new double[observations.size()];
235            int i = 0;
236            for (WeightedObservedPoint observed : observations) {
237                values[i++] = f.value(observed.getX(), point);
238            }
239
240            return values;
241        }
242    }
243
244    /** Vectorial function computing function theoretical values. */
245    private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction {
246
247        /** Function to fit. */
248        private final ParametricUnivariateFunction f;
249
250        /** Simple constructor.
251         * @param f function to fit.
252         */
253        public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
254            this.f = f;
255        }
256
257        /** {@inheritDoc} */
258        public double[] value(double[] point) {
259            // compute the residuals
260            final double[] values = new double[observations.size()];
261            int i = 0;
262            for (WeightedObservedPoint observed : observations) {
263                values[i++] = f.value(observed.getX(), point);
264            }
265
266            return values;
267        }
268
269        /** {@inheritDoc} */
270        public DerivativeStructure[] value(DerivativeStructure[] point) {
271
272            // extract parameters
273            final double[] parameters = new double[point.length];
274            for (int k = 0; k < point.length; ++k) {
275                parameters[k] = point[k].getValue();
276            }
277
278            // compute the residuals
279            final DerivativeStructure[] values = new DerivativeStructure[observations.size()];
280            int i = 0;
281            for (WeightedObservedPoint observed : observations) {
282
283                // build the DerivativeStructure by adding first the value as a constant
284                // and then adding derivatives
285                DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters));
286                for (int k = 0; k < point.length; ++k) {
287                    vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0));
288                }
289
290                values[i++] = vi;
291
292            }
293
294            return values;
295        }
296
297    }
298
299}