1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.fitting;
18
19 import java.util.Arrays;
20 import java.util.Collection;
21
22 import org.apache.commons.math4.legacy.analysis.MultivariateMatrixFunction;
23 import org.apache.commons.math4.legacy.analysis.MultivariateVectorFunction;
24 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
25 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresOptimizer;
26 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
27 import org.apache.commons.math4.legacy.fitting.leastsquares.LevenbergMarquardtOptimizer;
28
29 /**
30 * Base class that contains common code for fitting parametric univariate
31 * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
32 * the independent variable and the <code>p<sub>i</sub></code> are the
33 * <em>parameters</em>.
34 * <br>
35 * A fitter will find the optimal values of the parameters by
36 * <em>fitting</em> the curve so it remains very close to a set of
37 * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
38 * {@code 0 <= k < N}.
39 * <br>
 * An algorithm usually performs the fit by finding the parameter
 * values that minimize the objective function
 * <pre><code>
 * ∑<sub>k</sub> (y<sub>k</sub> - f(x<sub>k</sub>))<sup>2</sup>,
 * </code></pre>
 * which is actually a least-squares problem.
46 * This class contains boilerplate code for calling the
47 * {@link #fit(Collection)} method for obtaining the parameters.
48 * The problem setup, such as the choice of optimization algorithm
49 * for fitting a specific function is delegated to subclasses.
50 *
51 * @since 3.3
52 */
53 public abstract class AbstractCurveFitter {
54 /**
55 * Fits a curve.
56 * This method computes the coefficients of the curve that best
57 * fit the sample of observed points.
58 *
59 * @param points Observations.
60 * @return the fitted parameters.
61 */
62 public double[] fit(Collection<WeightedObservedPoint> points) {
63 // Perform the fit.
64 return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
65 }
66
67 /**
68 * Creates an optimizer set up to fit the appropriate curve.
69 * <p>
70 * The default implementation uses a {@link LevenbergMarquardtOptimizer
71 * Levenberg-Marquardt} optimizer.
72 * </p>
73 * @return the optimizer to use for fitting the curve to the
74 * given {@code points}.
75 */
76 protected LeastSquaresOptimizer getOptimizer() {
77 return new LevenbergMarquardtOptimizer();
78 }
79
80 /**
81 * Creates a least squares problem corresponding to the appropriate curve.
82 *
83 * @param points Sample points.
84 * @return the least squares problem to use for fitting the curve to the
85 * given {@code points}.
86 */
87 protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);
88
89 /**
90 * Vector function for computing function theoretical values.
91 */
92 protected static class TheoreticalValuesFunction {
93 /** Function to fit. */
94 private final ParametricUnivariateFunction f;
95 /** Observations. */
96 private final double[] points;
97
98 /**
99 * @param f function to fit.
100 * @param observations Observations.
101 */
102 public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
103 final Collection<WeightedObservedPoint> observations) {
104 this.f = f;
105 this.points = observations.stream().mapToDouble(WeightedObservedPoint::getX).toArray();
106 }
107
108 /**
109 * @return the model function values.
110 */
111 public MultivariateVectorFunction getModelFunction() {
112 return new MultivariateVectorFunction() {
113 /** {@inheritDoc} */
114 @Override
115 public double[] value(double[] p) {
116 return Arrays.stream(points).map(point -> f.value(point, p)).toArray();
117 }
118 };
119 }
120
121 /**
122 * @return the model function Jacobian.
123 */
124 public MultivariateMatrixFunction getModelFunctionJacobian() {
125 return new MultivariateMatrixFunction() {
126 /** {@inheritDoc} */
127 @Override
128 public double[][] value(double[] p) {
129 final int len = points.length;
130 final double[][] jacobian = new double[len][];
131 for (int i = 0; i < len; i++) {
132 jacobian[i] = f.gradient(points[i], p);
133 }
134 return jacobian;
135 }
136 };
137 }
138 }
139 }