1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
20 import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
21 import org.apache.commons.math4.legacy.linear.RealVector;
22 import org.apache.commons.math4.legacy.stat.descriptive.StatisticalSummary;
23 import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
24 import org.apache.commons.math4.core.jdkmath.JdkMath;
25 import org.junit.Assert;
26 import org.junit.Test;
27
28 import java.awt.geom.Point2D;
29 import java.util.ArrayList;
30 import java.util.List;
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49 public class EvaluationTestValidation {
50
51 private static final int MONTE_CARLO_RUNS = Integer.parseInt(System.getProperty("mcRuns",
52 "100"));
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69 @Test
70 public void testParametersErrorMonteCarloObservations() {
71
72 final double yError = 15;
73
74
75 final double slope = 123.456;
76 final double offset = -98.765;
77
78
79 final RandomStraightLinePointGenerator lineGenerator
80 = new RandomStraightLinePointGenerator(slope, offset,
81 yError,
82 -1e3, 1e4,
83 138577L);
84
85
86 final int numObs = 100;
87
88 final int numParams = 2;
89
90
91 final SummaryStatistics[] paramsFoundByDirectSolution = new SummaryStatistics[numParams];
92
93
94 final SummaryStatistics[] sigmaEstimate = new SummaryStatistics[numParams];
95
96
97 for (int i = 0; i < numParams; i++) {
98 paramsFoundByDirectSolution[i] = new SummaryStatistics();
99 sigmaEstimate[i] = new SummaryStatistics();
100 }
101
102 final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
103
104
105 final int mcRepeat = MONTE_CARLO_RUNS;
106 int mcCount = 0;
107 while (mcCount < mcRepeat) {
108
109 final Point2D.Double[] obs = lineGenerator.generate(numObs);
110
111 final StraightLineProblem problem = new StraightLineProblem(yError);
112 for (int i = 0; i < numObs; i++) {
113 final Point2D.Double p = obs[i];
114 problem.addPoint(p.x, p.y);
115 }
116
117
118 final double[] regress = problem.solve();
119
120
121
122 final LeastSquaresProblem lsp = builder(problem).build();
123
124 final RealVector sigma = lsp.evaluate(init).getSigma(1e-14);
125
126
127 for (int i = 0; i < numParams; i++) {
128 paramsFoundByDirectSolution[i].addValue(regress[i]);
129 sigmaEstimate[i].addValue(sigma.getEntry(i));
130 }
131
132
133 ++mcCount;
134 }
135
136
137 final String line = "--------------------------------------------------------------";
138 System.out.println(" True value Mean Std deviation");
139 for (int i = 0; i < numParams; i++) {
140 System.out.println(line);
141 System.out.println("Parameter #" + i);
142
143 StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
144 System.out.printf(" %+.6e %+.6e %+.6e\n",
145 init.getEntry(i),
146 s.getMean(),
147 s.getStandardDeviation());
148
149 s = sigmaEstimate[i].getSummary();
150 System.out.printf("sigma: %+.6e (%+.6e)\n",
151 s.getMean(),
152 s.getStandardDeviation());
153 }
154 System.out.println(line);
155
156
157 for (int i = 0; i < numParams; i++) {
158 Assert.assertEquals(paramsFoundByDirectSolution[i].getSummary().getStandardDeviation(),
159 sigmaEstimate[i].getSummary().getMean(),
160 8e-2);
161 }
162 }
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188 @Test
189 public void testParametersErrorMonteCarloParameters() {
190
191 final double yError = 15;
192
193
194 final double slope = 123.456;
195 final double offset = -98.765;
196
197
198 final RandomStraightLinePointGenerator lineGenerator
199 = new RandomStraightLinePointGenerator(slope, offset,
200 yError,
201 -1e3, 1e4,
202 13839013L);
203
204
205 final int numObs = 10;
206
207
208
209 final Point2D.Double[] obs = lineGenerator.generate(numObs);
210
211 final StraightLineProblem problem = new StraightLineProblem(yError);
212 for (int i = 0; i < numObs; i++) {
213 final Point2D.Double p = obs[i];
214 problem.addPoint(p.x, p.y);
215 }
216
217
218 final RealVector regress = new ArrayRealVector(problem.solve(), false);
219
220
221 final LeastSquaresProblem lsp = builder(problem).build();
222
223
224
225 final double bestChi2N = getChi2N(lsp, regress);
226 final RealVector sigma = lsp.evaluate(regress).getSigma(1e-14);
227
228
229 final int mcRepeat = MONTE_CARLO_RUNS;
230 final int gridSize = (int) JdkMath.sqrt(mcRepeat);
231
232
233
234
235
236 final List<double[]> paramsAndChi2 = new ArrayList<>(gridSize * gridSize);
237
238 final double slopeRange = 10 * sigma.getEntry(0);
239 final double offsetRange = 10 * sigma.getEntry(1);
240 final double minSlope = slope - 0.5 * slopeRange;
241 final double minOffset = offset - 0.5 * offsetRange;
242 final double deltaSlope = slopeRange/ gridSize;
243 final double deltaOffset = offsetRange / gridSize;
244 for (int i = 0; i < gridSize; i++) {
245 final double s = minSlope + i * deltaSlope;
246 for (int j = 0; j < gridSize; j++) {
247 final double o = minOffset + j * deltaOffset;
248 final double chi2N = getChi2N(lsp,
249 new ArrayRealVector(new double[] {s, o}, false));
250
251 paramsAndChi2.add(new double[] {s, o, chi2N});
252 }
253 }
254
255
256
257
258
259
260 final double chi2NPlusOne = bestChi2N + 1;
261 int numLarger = 0;
262
263 final String lineFmt = "%+.10e %+.10e %.8e\n";
264
265
266 System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
267 System.out.println();
268
269
270 for (double[] d : paramsAndChi2) {
271 if (d[2] <= chi2NPlusOne) {
272 System.out.printf(lineFmt, d[0], d[1], d[2]);
273 }
274 }
275 System.out.println();
276
277
278 for (double[] d : paramsAndChi2) {
279 if (d[2] > chi2NPlusOne) {
280 ++numLarger;
281 System.out.printf(lineFmt, d[0], d[1], d[2]);
282 }
283 }
284 System.out.println();
285
286 System.out.println("# sigma=" + sigma.toString());
287 System.out.println("# " + numLarger + " sets filtered out");
288 }
289
290 LeastSquaresBuilder builder(StraightLineProblem problem){
291 return new LeastSquaresBuilder()
292 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
293 .target(problem.target())
294 .weight(new DiagonalMatrix(problem.weight()))
295
296 .start(new double[2]);
297 }
298
299
300
301 private double getChi2N(LeastSquaresProblem lsp,
302 RealVector params) {
303 final double cost = lsp.evaluate(params).getCost();
304 return cost * cost / (lsp.getObservationSize() - params.getDimension());
305 }
306 }
307