001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.optimization.fitting;
019
020import java.util.ArrayList;
021import java.util.List;
022
023import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
024import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
025import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
026import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
027import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
028import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
029import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer;
030import org.apache.commons.math3.optimization.PointVectorValuePair;
031
032/** Fitter for parametric univariate real functions y = f(x).
033 * <br/>
034 * When a univariate real function y = f(x) does depend on some
035 * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
036 * this class can be used to find these parameters. It does this
037 * by <em>fitting</em> the curve so it remains very close to a set of
038 * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
039 * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
040 * is done by finding the parameters values that minimizes the objective
041 * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
042 * really a least squares problem.
043 *
044 * @param <T> Function to use for the fit.
045 *
046 * @deprecated As of 3.1 (to be removed in 4.0).
047 * @since 2.0
048 */
049@Deprecated
050public class CurveFitter<T extends ParametricUnivariateFunction> {
051
052    /** Optimizer to use for the fitting.
053     * @deprecated as of 3.1 replaced by {@link #optimizer}
054     */
055    @Deprecated
056    private final DifferentiableMultivariateVectorOptimizer oldOptimizer;
057
058    /** Optimizer to use for the fitting. */
059    private final MultivariateDifferentiableVectorOptimizer optimizer;
060
061    /** Observed points. */
062    private final List<WeightedObservedPoint> observations;
063
064    /** Simple constructor.
065     * @param optimizer optimizer to use for the fitting
066     * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
067     */
068    @Deprecated
069    public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
070        this.oldOptimizer = optimizer;
071        this.optimizer    = null;
072        observations      = new ArrayList<WeightedObservedPoint>();
073    }
074
075    /** Simple constructor.
076     * @param optimizer optimizer to use for the fitting
077     * @since 3.1
078     */
079    public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
080        this.oldOptimizer = null;
081        this.optimizer    = optimizer;
082        observations      = new ArrayList<WeightedObservedPoint>();
083    }
084
085    /** Add an observed (x,y) point to the sample with unit weight.
086     * <p>Calling this method is equivalent to call
087     * {@code addObservedPoint(1.0, x, y)}.</p>
088     * @param x abscissa of the point
089     * @param y observed value of the point at x, after fitting we should
090     * have f(x) as close as possible to this value
091     * @see #addObservedPoint(double, double, double)
092     * @see #addObservedPoint(WeightedObservedPoint)
093     * @see #getObservations()
094     */
095    public void addObservedPoint(double x, double y) {
096        addObservedPoint(1.0, x, y);
097    }
098
099    /** Add an observed weighted (x,y) point to the sample.
100     * @param weight weight of the observed point in the fit
101     * @param x abscissa of the point
102     * @param y observed value of the point at x, after fitting we should
103     * have f(x) as close as possible to this value
104     * @see #addObservedPoint(double, double)
105     * @see #addObservedPoint(WeightedObservedPoint)
106     * @see #getObservations()
107     */
108    public void addObservedPoint(double weight, double x, double y) {
109        observations.add(new WeightedObservedPoint(weight, x, y));
110    }
111
112    /** Add an observed weighted (x,y) point to the sample.
113     * @param observed observed point to add
114     * @see #addObservedPoint(double, double)
115     * @see #addObservedPoint(double, double, double)
116     * @see #getObservations()
117     */
118    public void addObservedPoint(WeightedObservedPoint observed) {
119        observations.add(observed);
120    }
121
122    /** Get the observed points.
123     * @return observed points
124     * @see #addObservedPoint(double, double)
125     * @see #addObservedPoint(double, double, double)
126     * @see #addObservedPoint(WeightedObservedPoint)
127     */
128    public WeightedObservedPoint[] getObservations() {
129        return observations.toArray(new WeightedObservedPoint[observations.size()]);
130    }
131
132    /**
133     * Remove all observations.
134     */
135    public void clearObservations() {
136        observations.clear();
137    }
138
139    /**
140     * Fit a curve.
141     * This method compute the coefficients of the curve that best
142     * fit the sample of observed points previously given through calls
143     * to the {@link #addObservedPoint(WeightedObservedPoint)
144     * addObservedPoint} method.
145     *
146     * @param f parametric function to fit.
147     * @param initialGuess first guess of the function parameters.
148     * @return the fitted parameters.
149     * @throws org.apache.commons.math3.exception.DimensionMismatchException
150     * if the start point dimension is wrong.
151     */
152    public double[] fit(T f, final double[] initialGuess) {
153        return fit(Integer.MAX_VALUE, f, initialGuess);
154    }
155
156    /**
157     * Fit a curve.
158     * This method compute the coefficients of the curve that best
159     * fit the sample of observed points previously given through calls
160     * to the {@link #addObservedPoint(WeightedObservedPoint)
161     * addObservedPoint} method.
162     *
163     * @param f parametric function to fit.
164     * @param initialGuess first guess of the function parameters.
165     * @param maxEval Maximum number of function evaluations.
166     * @return the fitted parameters.
167     * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
168     * if the number of allowed evaluations is exceeded.
169     * @throws org.apache.commons.math3.exception.DimensionMismatchException
170     * if the start point dimension is wrong.
171     * @since 3.0
172     */
173    public double[] fit(int maxEval, T f,
174                        final double[] initialGuess) {
175        // prepare least squares problem
176        double[] target  = new double[observations.size()];
177        double[] weights = new double[observations.size()];
178        int i = 0;
179        for (WeightedObservedPoint point : observations) {
180            target[i]  = point.getY();
181            weights[i] = point.getWeight();
182            ++i;
183        }
184
185        // perform the fit
186        final PointVectorValuePair optimum;
187        if (optimizer == null) {
188            // to be removed in 4.0
189            optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
190                                            target, weights, initialGuess);
191        } else {
192            optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
193                                         target, weights, initialGuess);
194        }
195
196        // extract the coefficients
197        return optimum.getPointRef();
198    }
199
200    /** Vectorial function computing function theoretical values. */
201    @Deprecated
202    private class OldTheoreticalValuesFunction
203        implements DifferentiableMultivariateVectorFunction {
204        /** Function to fit. */
205        private final ParametricUnivariateFunction f;
206
207        /** Simple constructor.
208         * @param f function to fit.
209         */
210        OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
211            this.f = f;
212        }
213
214        /** {@inheritDoc} */
215        public MultivariateMatrixFunction jacobian() {
216            return new MultivariateMatrixFunction() {
217                /** {@inheritDoc} */
218                public double[][] value(double[] point) {
219                    final double[][] jacobian = new double[observations.size()][];
220
221                    int i = 0;
222                    for (WeightedObservedPoint observed : observations) {
223                        jacobian[i++] = f.gradient(observed.getX(), point);
224                    }
225
226                    return jacobian;
227                }
228            };
229        }
230
231        /** {@inheritDoc} */
232        public double[] value(double[] point) {
233            // compute the residuals
234            final double[] values = new double[observations.size()];
235            int i = 0;
236            for (WeightedObservedPoint observed : observations) {
237                values[i++] = f.value(observed.getX(), point);
238            }
239
240            return values;
241        }
242    }
243
244    /** Vectorial function computing function theoretical values. */
245    private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction {
246
247        /** Function to fit. */
248        private final ParametricUnivariateFunction f;
249
250        /** Simple constructor.
251         * @param f function to fit.
252         */
253        TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
254            this.f = f;
255        }
256
257        /** {@inheritDoc} */
258        public double[] value(double[] point) {
259            // compute the residuals
260            final double[] values = new double[observations.size()];
261            int i = 0;
262            for (WeightedObservedPoint observed : observations) {
263                values[i++] = f.value(observed.getX(), point);
264            }
265
266            return values;
267        }
268
269        /** {@inheritDoc} */
270        public DerivativeStructure[] value(DerivativeStructure[] point) {
271
272            // extract parameters
273            final double[] parameters = new double[point.length];
274            for (int k = 0; k < point.length; ++k) {
275                parameters[k] = point[k].getValue();
276            }
277
278            // compute the residuals
279            final DerivativeStructure[] values = new DerivativeStructure[observations.size()];
280            int i = 0;
281            for (WeightedObservedPoint observed : observations) {
282
283                // build the DerivativeStructure by adding first the value as a constant
284                // and then adding derivatives
285                DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters));
286                for (int k = 0; k < point.length; ++k) {
287                    vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0));
288                }
289
290                values[i++] = vi;
291
292            }
293
294            return values;
295        }
296
297    }
298
299}