001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.fitting;
018
019import java.util.Arrays;
020import java.util.Collection;
021
022import org.apache.commons.math4.analysis.MultivariateMatrixFunction;
023import org.apache.commons.math4.analysis.MultivariateVectorFunction;
024import org.apache.commons.math4.analysis.ParametricUnivariateFunction;
025import org.apache.commons.math4.fitting.leastsquares.LeastSquaresOptimizer;
026import org.apache.commons.math4.fitting.leastsquares.LeastSquaresProblem;
027import org.apache.commons.math4.fitting.leastsquares.LevenbergMarquardtOptimizer;
028
029/**
030 * Base class that contains common code for fitting parametric univariate
031 * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
032 * the independent variable and the <code>p<sub>i</sub></code> are the
033 * <em>parameters</em>.
034 * <br>
035 * A fitter will find the optimal values of the parameters by
036 * <em>fitting</em> the curve so it remains very close to a set of
037 * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
038 * {@code 0 <= k < N}.
039 * <br>
040 * An algorithm usually performs the fit by finding the parameter
041 * values that minimizes the objective function
042 * <pre><code>
043 *  &sum;y<sub>k</sub> - f(x<sub>k</sub>)<sup>2</sup>,
044 * </code></pre>
045 * which is actually a least-squares problem.
046 * This class contains boilerplate code for calling the
047 * {@link #fit(Collection)} method for obtaining the parameters.
048 * The problem setup, such as the choice of optimization algorithm
049 * for fitting a specific function is delegated to subclasses.
050 *
051 * @since 3.3
052 */
053public abstract class AbstractCurveFitter {
054    /**
055     * Fits a curve.
056     * This method computes the coefficients of the curve that best
057     * fit the sample of observed points.
058     *
059     * @param points Observations.
060     * @return the fitted parameters.
061     */
062    public double[] fit(Collection<WeightedObservedPoint> points) {
063        // Perform the fit.
064        return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
065    }
066
067    /**
068     * Creates an optimizer set up to fit the appropriate curve.
069     * <p>
070     * The default implementation uses a {@link LevenbergMarquardtOptimizer
071     * Levenberg-Marquardt} optimizer.
072     * </p>
073     * @return the optimizer to use for fitting the curve to the
074     * given {@code points}.
075     */
076    protected LeastSquaresOptimizer getOptimizer() {
077        return new LevenbergMarquardtOptimizer();
078    }
079
080    /**
081     * Creates a least squares problem corresponding to the appropriate curve.
082     *
083     * @param points Sample points.
084     * @return the least squares problem to use for fitting the curve to the
085     * given {@code points}.
086     */
087    protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);
088
089    /**
090     * Vector function for computing function theoretical values.
091     */
092    protected static class TheoreticalValuesFunction {
093        /** Function to fit. */
094        private final ParametricUnivariateFunction f;
095        /** Observations. */
096        private final double[] points;
097
098        /**
099         * @param f function to fit.
100         * @param observations Observations.
101         */
102        public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
103                                         final Collection<WeightedObservedPoint> observations) {
104            this.f = f;
105            this.points = observations.stream().mapToDouble(WeightedObservedPoint::getX).toArray();
106        }
107
108        /**
109         * @return the model function values.
110         */
111        public MultivariateVectorFunction getModelFunction() {
112            return new MultivariateVectorFunction() {
113                /** {@inheritDoc} */
114                @Override
115                public double[] value(double[] p) {
116                    return Arrays.stream(points).map(point -> f.value(point, p)).toArray();
117                }
118            };
119        }
120
121        /**
122         * @return the model function Jacobian.
123         */
124        public MultivariateMatrixFunction getModelFunctionJacobian() {
125            return new MultivariateMatrixFunction() {
126                /** {@inheritDoc} */
127                @Override
128                public double[][] value(double[] p) {
129                    final int len = points.length;
130                    final double[][] jacobian = new double[len][];
131                    for (int i = 0; i < len; i++) {
132                        jacobian[i] = f.gradient(points[i], p);
133                    }
134                    return jacobian;
135                }
136            };
137        }
138    }
139}