1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting;
18
19 import java.util.Collections;
20 import java.util.Collection;
21 import java.util.Comparator;
22 import java.util.List;
23 import java.util.ArrayList;
24
25 import org.apache.commons.math4.legacy.exception.ZeroException;
26 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
27 import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
28 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
29 import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
30 import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
31
32
33
34
35
36
37 public class SimpleCurveFitter extends AbstractCurveFitter {
38
39 private final ParametricUnivariateFunction function;
40
41 private final double[] initialGuess;
42
43 private final ParameterGuesser guesser;
44
45 private final int maxIter;
46
47
48
49
50
51
52
53
54
55
56
57 protected SimpleCurveFitter(ParametricUnivariateFunction function,
58 double[] initialGuess,
59 ParameterGuesser guesser,
60 int maxIter) {
61 this.function = function;
62 this.initialGuess = initialGuess;
63 this.guesser = guesser;
64 this.maxIter = maxIter;
65 }
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 public static SimpleCurveFitter create(ParametricUnivariateFunction f,
82 double[] start) {
83 return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
84 }
85
86
87
88
89
90
91
92
93
94
95
96
97
98 public static SimpleCurveFitter create(ParametricUnivariateFunction f,
99 ParameterGuesser guesser) {
100 return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
101 }
102
103
104
105
106
107
108 public SimpleCurveFitter withStartPoint(double[] newStart) {
109 return new SimpleCurveFitter(function,
110 newStart.clone(),
111 null,
112 maxIter);
113 }
114
115
116
117
118
119
120 public SimpleCurveFitter withMaxIterations(int newMaxIter) {
121 return new SimpleCurveFitter(function,
122 initialGuess,
123 guesser,
124 newMaxIter);
125 }
126
127
128 @Override
129 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
130
131 final int len = observations.size();
132 final double[] target = new double[len];
133 final double[] weights = new double[len];
134
135 int count = 0;
136 for (WeightedObservedPoint obs : observations) {
137 target[count] = obs.getY();
138 weights[count] = obs.getWeight();
139 ++count;
140 }
141
142 final AbstractCurveFitter.TheoreticalValuesFunction model
143 = new AbstractCurveFitter.TheoreticalValuesFunction(function,
144 observations);
145
146 final double[] startPoint = initialGuess != null ?
147 initialGuess :
148
149 guesser.guess(observations);
150
151
152 return new LeastSquaresBuilder().
153 maxEvaluations(Integer.MAX_VALUE).
154 maxIterations(maxIter).
155 start(startPoint).
156 target(target).
157 weight(new DiagonalMatrix(weights)).
158 model(model.getModelFunction(), model.getModelFunctionJacobian()).
159 build();
160 }
161
162
163
164
165 public abstract static class ParameterGuesser {
166
167 private static final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
168
169 @Override
170 public int compare(WeightedObservedPoint p1,
171 WeightedObservedPoint p2) {
172 if (p1 == null && p2 == null) {
173 return 0;
174 }
175 if (p1 == null) {
176 return -1;
177 }
178 if (p2 == null) {
179 return 1;
180 }
181 int comp = Double.compare(p1.getX(), p2.getX());
182 if (comp != 0) {
183 return comp;
184 }
185 comp = Double.compare(p1.getY(), p2.getY());
186 if (comp != 0) {
187 return comp;
188 }
189 return Double.compare(p1.getWeight(), p2.getWeight());
190 }
191 };
192
193
194
195
196
197
198
199 public abstract double[] guess(Collection<WeightedObservedPoint> obs);
200
201
202
203
204
205
206
207 protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
208 final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
209 Collections.sort(observations, CMP);
210 return observations;
211 }
212
213
214
215
216
217
218
219 protected int findMaxY(WeightedObservedPoint[] points) {
220 int maxYIdx = 0;
221 for (int i = 1; i < points.length; i++) {
222 if (points[i].getY() > points[maxYIdx].getY()) {
223 maxYIdx = i;
224 }
225 }
226 return maxYIdx;
227 }
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243 protected double interpolateXAtY(WeightedObservedPoint[] points,
244 int startIdx,
245 int idxStep,
246 double y) {
247 if (idxStep == 0) {
248 throw new ZeroException();
249 }
250 final WeightedObservedPoint[] twoPoints
251 = getInterpolationPointsForY(points, startIdx, idxStep, y);
252 final WeightedObservedPoint p1 = twoPoints[0];
253 final WeightedObservedPoint p2 = twoPoints[1];
254 if (p1.getY() == y) {
255 return p1.getX();
256 }
257 if (p2.getY() == y) {
258 return p2.getX();
259 }
260 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
261 (p2.getY() - p1.getY()));
262 }
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
280 int startIdx,
281 int idxStep,
282 double y) {
283 if (idxStep == 0) {
284 throw new ZeroException();
285 }
286 for (int i = startIdx;
287 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
288 i += idxStep) {
289 final WeightedObservedPoint p1 = points[i];
290 final WeightedObservedPoint p2 = points[i + idxStep];
291 if (isBetween(y, p1.getY(), p2.getY())) {
292 if (idxStep < 0) {
293 return new WeightedObservedPoint[] { p2, p1 };
294 } else {
295 return new WeightedObservedPoint[] { p1, p2 };
296 }
297 }
298 }
299
300
301
302
303 throw new OutOfRangeException(y,
304 Double.NEGATIVE_INFINITY,
305 Double.POSITIVE_INFINITY);
306 }
307
308
309
310
311
312
313
314
315
316
317
318 private boolean isBetween(double value,
319 double boundary1,
320 double boundary2) {
321 return (value >= boundary1 && value <= boundary2) ||
322 (value >= boundary2 && value <= boundary1);
323 }
324 }
325 }