001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.fitting;
018
019import java.util.Collection;
020
021import org.apache.commons.math3.analysis.MultivariateVectorFunction;
022import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
023import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
024import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
025import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
026import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
027
028/**
029 * Base class that contains common code for fitting parametric univariate
030 * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
031 * the independent variable and the <code>p<sub>i</sub></code> are the
032 * <em>parameters</em>.
 * <p>
034 * A fitter will find the optimal values of the parameters by
035 * <em>fitting</em> the curve so it remains very close to a set of
036 * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
037 * {@code 0 <= k < N}.
 * <p>
 * An algorithm usually performs the fit by finding the parameter
 * values that minimize the objective function
 * <pre><code>
 *  &sum;(y<sub>k</sub> - f(x<sub>k</sub>))<sup>2</sup>,
 * </code></pre>
 * which is actually a least-squares problem.
045 * This class contains boilerplate code for calling the
046 * {@link #fit(Collection)} method for obtaining the parameters.
047 * The problem setup, such as the choice of optimization algorithm
048 * for fitting a specific function is delegated to subclasses.
049 *
050 * @since 3.3
051 */
052public abstract class AbstractCurveFitter {
053    /**
054     * Fits a curve.
055     * This method computes the coefficients of the curve that best
056     * fit the sample of observed points.
057     *
058     * @param points Observations.
059     * @return the fitted parameters.
060     */
061    public double[] fit(Collection<WeightedObservedPoint> points) {
062        // Perform the fit.
063        return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
064    }
065
066    /**
067     * Creates an optimizer set up to fit the appropriate curve.
068     * <p>
069     * The default implementation uses a {@link LevenbergMarquardtOptimizer
070     * Levenberg-Marquardt} optimizer.
071     * </p>
072     * @return the optimizer to use for fitting the curve to the
073     * given {@code points}.
074     */
075    protected LeastSquaresOptimizer getOptimizer() {
076        return new LevenbergMarquardtOptimizer();
077    }
078
079    /**
080     * Creates a least squares problem corresponding to the appropriate curve.
081     *
082     * @param points Sample points.
083     * @return the least squares problem to use for fitting the curve to the
084     * given {@code points}.
085     */
086    protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);
087
088    /**
089     * Vector function for computing function theoretical values.
090     */
091    protected static class TheoreticalValuesFunction {
092        /** Function to fit. */
093        private final ParametricUnivariateFunction f;
094        /** Observations. */
095        private final double[] points;
096
097        /**
098         * @param f function to fit.
099         * @param observations Observations.
100         */
101        public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
102                                         final Collection<WeightedObservedPoint> observations) {
103            this.f = f;
104
105            final int len = observations.size();
106            this.points = new double[len];
107            int i = 0;
108            for (WeightedObservedPoint obs : observations) {
109                this.points[i++] = obs.getX();
110            }
111        }
112
113        /**
114         * @return the model function values.
115         */
116        public MultivariateVectorFunction getModelFunction() {
117            return new MultivariateVectorFunction() {
118                /** {@inheritDoc} */
119                public double[] value(double[] p) {
120                    final int len = points.length;
121                    final double[] values = new double[len];
122                    for (int i = 0; i < len; i++) {
123                        values[i] = f.value(points[i], p);
124                    }
125
126                    return values;
127                }
128            };
129        }
130
131        /**
132         * @return the model function Jacobian.
133         */
134        public MultivariateMatrixFunction getModelFunctionJacobian() {
135            return new MultivariateMatrixFunction() {
136                /** {@inheritDoc} */
137                public double[][] value(double[] p) {
138                    final int len = points.length;
139                    final double[][] jacobian = new double[len][];
140                    for (int i = 0; i < len; i++) {
141                        jacobian[i] = f.gradient(points[i], p);
142                    }
143                    return jacobian;
144                }
145            };
146        }
147    }
148}