1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.fitting;
18
19 import java.util.Collections;
20 import java.util.Collection;
21 import java.util.Comparator;
22 import java.util.List;
23 import java.util.ArrayList;
24
25 import org.apache.commons.math4.legacy.exception.ZeroException;
26 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
27 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
28 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
29 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
30 import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
31
32 /**
33 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
34 *
35 * @since 3.4
36 */
37 public class SimpleCurveFitter extends AbstractCurveFitter {
38 /** Function to fit. */
39 private final ParametricUnivariateFunction function;
40 /** Initial guess for the parameters. */
41 private final double[] initialGuess;
42 /** Parameter guesser. */
43 private final ParameterGuesser guesser;
44 /** Maximum number of iterations of the optimization algorithm. */
45 private final int maxIter;
46
47 /**
48 * Constructor used by the factory methods.
49 *
50 * @param function Function to fit.
51 * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
52 * be consistent with the number of parameters of the {@code function} to fit.
53 * @param guesser Method for providing an initial guess (if {@code initialGuess}
54 * is {@code null}).
55 * @param maxIter Maximum number of iterations of the optimization algorithm.
56 */
57 protected SimpleCurveFitter(ParametricUnivariateFunction function,
58 double[] initialGuess,
59 ParameterGuesser guesser,
60 int maxIter) {
61 this.function = function;
62 this.initialGuess = initialGuess;
63 this.guesser = guesser;
64 this.maxIter = maxIter;
65 }
66
67 /**
68 * Creates a curve fitter.
69 * The maximum number of iterations of the optimization algorithm is set
70 * to {@link Integer#MAX_VALUE}.
71 *
72 * @param f Function to fit.
73 * @param start Initial guess for the parameters. Cannot be {@code null}.
74 * Its length must be consistent with the number of parameters of the
75 * function to fit.
76 * @return a curve fitter.
77 *
78 * @see #withStartPoint(double[])
79 * @see #withMaxIterations(int)
80 */
81 public static SimpleCurveFitter create(ParametricUnivariateFunction f,
82 double[] start) {
83 return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
84 }
85
86 /**
87 * Creates a curve fitter.
88 * The maximum number of iterations of the optimization algorithm is set
89 * to {@link Integer#MAX_VALUE}.
90 *
91 * @param f Function to fit.
92 * @param guesser Method for providing an initial guess.
93 * @return a curve fitter.
94 *
95 * @see #withStartPoint(double[])
96 * @see #withMaxIterations(int)
97 */
98 public static SimpleCurveFitter create(ParametricUnivariateFunction f,
99 ParameterGuesser guesser) {
100 return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
101 }
102
103 /**
104 * Configure the start point (initial guess).
105 * @param newStart new start point (initial guess)
106 * @return a new instance.
107 */
108 public SimpleCurveFitter withStartPoint(double[] newStart) {
109 return new SimpleCurveFitter(function,
110 newStart.clone(),
111 null,
112 maxIter);
113 }
114
115 /**
116 * Configure the maximum number of iterations.
117 * @param newMaxIter maximum number of iterations
118 * @return a new instance.
119 */
120 public SimpleCurveFitter withMaxIterations(int newMaxIter) {
121 return new SimpleCurveFitter(function,
122 initialGuess,
123 guesser,
124 newMaxIter);
125 }
126
127 /** {@inheritDoc} */
128 @Override
129 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
130 // Prepare least-squares problem.
131 final int len = observations.size();
132 final double[] target = new double[len];
133 final double[] weights = new double[len];
134
135 int count = 0;
136 for (WeightedObservedPoint obs : observations) {
137 target[count] = obs.getY();
138 weights[count] = obs.getWeight();
139 ++count;
140 }
141
142 final AbstractCurveFitter.TheoreticalValuesFunction model
143 = new AbstractCurveFitter.TheoreticalValuesFunction(function,
144 observations);
145
146 final double[] startPoint = initialGuess != null ?
147 initialGuess :
148 // Compute estimation.
149 guesser.guess(observations);
150
151 // Create an optimizer for fitting the curve to the observed points.
152 return new LeastSquaresBuilder().
153 maxEvaluations(Integer.MAX_VALUE).
154 maxIterations(maxIter).
155 start(startPoint).
156 target(target).
157 weight(new DiagonalMatrix(weights)).
158 model(model.getModelFunction(), model.getModelFunctionJacobian()).
159 build();
160 }
161
162 /**
163 * Guesses the parameters.
164 */
165 public abstract static class ParameterGuesser {
166 /** Comparator. */
167 private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
168 /** {@inheritDoc} */
169 @Override
170 public int compare(WeightedObservedPoint p1,
171 WeightedObservedPoint p2) {
172 if (p1 == null && p2 == null) {
173 return 0;
174 }
175 if (p1 == null) {
176 return -1;
177 }
178 if (p2 == null) {
179 return 1;
180 }
181 int comp = Double.compare(p1.getX(), p2.getX());
182 if (comp != 0) {
183 return comp;
184 }
185 comp = Double.compare(p1.getY(), p2.getY());
186 if (comp != 0) {
187 return comp;
188 }
189 return Double.compare(p1.getWeight(), p2.getWeight());
190 }
191 };
192
193 /**
194 * Computes an estimation of the parameters.
195 *
196 * @param obs Observations.
197 * @return the guessed parameters.
198 */
199 public abstract double[] guess(Collection<WeightedObservedPoint> obs);
200
201 /**
202 * Sort the observations.
203 *
204 * @param unsorted Input observations.
205 * @return the input observations, sorted.
206 */
207 protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
208 final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
209 Collections.sort(observations, CMP);
210 return observations;
211 }
212
213 /**
214 * Finds index of point in specified points with the largest Y.
215 *
216 * @param points Points to search.
217 * @return the index in specified points array.
218 */
219 protected int findMaxY(WeightedObservedPoint[] points) {
220 int maxYIdx = 0;
221 for (int i = 1; i < points.length; i++) {
222 if (points[i].getY() > points[maxYIdx].getY()) {
223 maxYIdx = i;
224 }
225 }
226 return maxYIdx;
227 }
228
229 /**
230 * Interpolates using the specified points to determine X at the
231 * specified Y.
232 *
233 * @param points Points to use for interpolation.
234 * @param startIdx Index within points from which to start the search for
235 * interpolation bounds points.
236 * @param idxStep Index step for searching interpolation bounds points.
237 * @param y Y value for which X should be determined.
238 * @return the value of X for the specified Y.
239 * @throws ZeroException if {@code idxStep} is 0.
240 * @throws OutOfRangeException if specified {@code y} is not within the
241 * range of the specified {@code points}.
242 */
243 protected double interpolateXAtY(WeightedObservedPoint[] points,
244 int startIdx,
245 int idxStep,
246 double y) {
247 if (idxStep == 0) {
248 throw new ZeroException();
249 }
250 final WeightedObservedPoint[] twoPoints
251 = getInterpolationPointsForY(points, startIdx, idxStep, y);
252 final WeightedObservedPoint p1 = twoPoints[0];
253 final WeightedObservedPoint p2 = twoPoints[1];
254 if (p1.getY() == y) {
255 return p1.getX();
256 }
257 if (p2.getY() == y) {
258 return p2.getX();
259 }
260 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
261 (p2.getY() - p1.getY()));
262 }
263
264 /**
265 * Gets the two bounding interpolation points from the specified points
266 * suitable for determining X at the specified Y.
267 *
268 * @param points Points to use for interpolation.
269 * @param startIdx Index within points from which to start search for
270 * interpolation bounds points.
271 * @param idxStep Index step for search for interpolation bounds points.
272 * @param y Y value for which X should be determined.
273 * @return the array containing two points suitable for determining X at
274 * the specified Y.
275 * @throws ZeroException if {@code idxStep} is 0.
276 * @throws OutOfRangeException if specified {@code y} is not within the
277 * range of the specified {@code points}.
278 */
279 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
280 int startIdx,
281 int idxStep,
282 double y) {
283 if (idxStep == 0) {
284 throw new ZeroException();
285 }
286 for (int i = startIdx;
287 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
288 i += idxStep) {
289 final WeightedObservedPoint p1 = points[i];
290 final WeightedObservedPoint p2 = points[i + idxStep];
291 if (isBetween(y, p1.getY(), p2.getY())) {
292 if (idxStep < 0) {
293 return new WeightedObservedPoint[] { p2, p1 };
294 } else {
295 return new WeightedObservedPoint[] { p1, p2 };
296 }
297 }
298 }
299
300 // Boundaries are replaced by dummy values because the raised
301 // exception is caught and the message never displayed.
302 // TODO: Exceptions should not be used for flow control.
303 throw new OutOfRangeException(y,
304 Double.NEGATIVE_INFINITY,
305 Double.POSITIVE_INFINITY);
306 }
307
308 /**
309 * Determines whether a value is between two other values.
310 *
311 * @param value Value to test whether it is between {@code boundary1}
312 * and {@code boundary2}.
313 * @param boundary1 One end of the range.
314 * @param boundary2 Other end of the range.
315 * @return {@code true} if {@code value} is between {@code boundary1} and
316 * {@code boundary2} (inclusive), {@code false} otherwise.
317 */
318 private boolean isBetween(double value,
319 double boundary1,
320 double boundary2) {
321 return (value >= boundary1 && value <= boundary2) ||
322 (value >= boundary2 && value <= boundary1);
323 }
324 }
325 }