001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math.optimization.fitting;
019
020 import java.util.ArrayList;
021 import java.util.List;
022
023 import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction;
024 import org.apache.commons.math.analysis.ParametricUnivariateRealFunction;
025 import org.apache.commons.math.analysis.MultivariateMatrixFunction;
026 import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer;
027 import org.apache.commons.math.optimization.VectorialPointValuePair;
028
029 /** Fitter for parametric univariate real functions y = f(x).
030 * <p>When a univariate real function y = f(x) does depend on some
031 * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
032 * this class can be used to find these parameters. It does this
033 * by <em>fitting</em> the curve so it remains very close to a set of
034 * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
035 * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
036 * is done by finding the parameters values that minimizes the objective
037 * function ∑(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
038 * really a least squares problem.</p>
039 * @version $Id: CurveFitter.java 1179928 2011-10-07 03:20:39Z psteitz $
040 * @since 2.0
041 */
042 public class CurveFitter {
043 /** Optimizer to use for the fitting. */
044 private final DifferentiableMultivariateVectorialOptimizer optimizer;
045 /** Observed points. */
046 private final List<WeightedObservedPoint> observations;
047
048 /** Simple constructor.
049 * @param optimizer optimizer to use for the fitting
050 */
051 public CurveFitter(final DifferentiableMultivariateVectorialOptimizer optimizer) {
052 this.optimizer = optimizer;
053 observations = new ArrayList<WeightedObservedPoint>();
054 }
055
056 /** Add an observed (x,y) point to the sample with unit weight.
057 * <p>Calling this method is equivalent to call
058 * {@code addObservedPoint(1.0, x, y)}.</p>
059 * @param x abscissa of the point
060 * @param y observed value of the point at x, after fitting we should
061 * have f(x) as close as possible to this value
062 * @see #addObservedPoint(double, double, double)
063 * @see #addObservedPoint(WeightedObservedPoint)
064 * @see #getObservations()
065 */
066 public void addObservedPoint(double x, double y) {
067 addObservedPoint(1.0, x, y);
068 }
069
070 /** Add an observed weighted (x,y) point to the sample.
071 * @param weight weight of the observed point in the fit
072 * @param x abscissa of the point
073 * @param y observed value of the point at x, after fitting we should
074 * have f(x) as close as possible to this value
075 * @see #addObservedPoint(double, double)
076 * @see #addObservedPoint(WeightedObservedPoint)
077 * @see #getObservations()
078 */
079 public void addObservedPoint(double weight, double x, double y) {
080 observations.add(new WeightedObservedPoint(weight, x, y));
081 }
082
083 /** Add an observed weighted (x,y) point to the sample.
084 * @param observed observed point to add
085 * @see #addObservedPoint(double, double)
086 * @see #addObservedPoint(double, double, double)
087 * @see #getObservations()
088 */
089 public void addObservedPoint(WeightedObservedPoint observed) {
090 observations.add(observed);
091 }
092
093 /** Get the observed points.
094 * @return observed points
095 * @see #addObservedPoint(double, double)
096 * @see #addObservedPoint(double, double, double)
097 * @see #addObservedPoint(WeightedObservedPoint)
098 */
099 public WeightedObservedPoint[] getObservations() {
100 return observations.toArray(new WeightedObservedPoint[observations.size()]);
101 }
102
103 /**
104 * Remove all observations.
105 */
106 public void clearObservations() {
107 observations.clear();
108 }
109
110 /**
111 * Fit a curve.
112 * This method compute the coefficients of the curve that best
113 * fit the sample of observed points previously given through calls
114 * to the {@link #addObservedPoint(WeightedObservedPoint)
115 * addObservedPoint} method.
116 *
117 * @param f parametric function to fit.
118 * @param initialGuess first guess of the function parameters.
119 * @return the fitted parameters.
120 * @throws org.apache.commons.math.exception.DimensionMismatchException
121 * if the start point dimension is wrong.
122 */
123 public double[] fit(final ParametricUnivariateRealFunction f, final double[] initialGuess) {
124 return fit(Integer.MAX_VALUE, f, initialGuess);
125 }
126
127 /**
128 * Fit a curve.
129 * This method compute the coefficients of the curve that best
130 * fit the sample of observed points previously given through calls
131 * to the {@link #addObservedPoint(WeightedObservedPoint)
132 * addObservedPoint} method.
133 *
134 * @param f parametric function to fit.
135 * @param initialGuess first guess of the function parameters.
136 * @param maxEval Maximum number of function evaluations.
137 * @return the fitted parameters.
138 * @throws org.apache.commons.math.exception.TooManyEvaluationsException
139 * if the number of allowed evaluations is exceeded.
140 * @throws org.apache.commons.math.exception.DimensionMismatchException
141 * if the start point dimension is wrong.
142 * @since 3.0
143 */
144 public double[] fit(int maxEval, final ParametricUnivariateRealFunction f,
145 final double[] initialGuess) {
146 // prepare least squares problem
147 double[] target = new double[observations.size()];
148 double[] weights = new double[observations.size()];
149 int i = 0;
150 for (WeightedObservedPoint point : observations) {
151 target[i] = point.getY();
152 weights[i] = point.getWeight();
153 ++i;
154 }
155
156 // perform the fit
157 VectorialPointValuePair optimum =
158 optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
159 target, weights, initialGuess);
160
161 // extract the coefficients
162 return optimum.getPointRef();
163 }
164
165 /** Vectorial function computing function theoretical values. */
166 private class TheoreticalValuesFunction
167 implements DifferentiableMultivariateVectorialFunction {
168 /** Function to fit. */
169 private final ParametricUnivariateRealFunction f;
170
171 /** Simple constructor.
172 * @param f function to fit.
173 */
174 public TheoreticalValuesFunction(final ParametricUnivariateRealFunction f) {
175 this.f = f;
176 }
177
178 /** {@inheritDoc} */
179 public MultivariateMatrixFunction jacobian() {
180 return new MultivariateMatrixFunction() {
181 public double[][] value(double[] point) {
182 final double[][] jacobian = new double[observations.size()][];
183
184 int i = 0;
185 for (WeightedObservedPoint observed : observations) {
186 jacobian[i++] = f.gradient(observed.getX(), point);
187 }
188
189 return jacobian;
190 }
191 };
192 }
193
194 /** {@inheritDoc} */
195 public double[] value(double[] point) {
196 // compute the residuals
197 final double[] values = new double[observations.size()];
198 int i = 0;
199 for (WeightedObservedPoint observed : observations) {
200 values[i++] = f.value(observed.getX(), point);
201 }
202
203 return values;
204 }
205 }
206 }