1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.fitting.leastsquares;
18
19 import org.apache.commons.math4.legacy.linear.ArrayRealVector;
20 import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
21 import org.apache.commons.math4.legacy.linear.RealVector;
22 import org.apache.commons.math4.legacy.stat.descriptive.StatisticalSummary;
23 import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
24 import org.apache.commons.math4.core.jdkmath.JdkMath;
25 import org.junit.Assert;
26 import org.junit.Test;
27
28 import java.awt.geom.Point2D;
29 import java.util.ArrayList;
30 import java.util.List;
31
32 /**
33 * This class demonstrates the main functionality of the
34 * {@link LeastSquaresProblem.Evaluation}, common to the
35 * optimizer implementations in package
36 * {@link org.apache.commons.math4.legacy.fitting.leastsquares}.
37 * <br>
38 * Not enabled by default, as the class name does not end with "Test".
39 * <br>
40 * Invoke by running
41 * <pre><code>
42 * mvn test -Dtest=EvaluationTestValidation
43 * </code></pre>
44 * or by running
45 * <pre><code>
46 * mvn test -Dtest=EvaluationTestValidation -DargLine="-DmcRuns=1234 -server"
47 * </code></pre>
48 */
49 public class EvaluationTestValidation {
50 /** Number of runs. */
51 private static final int MONTE_CARLO_RUNS = Integer.parseInt(System.getProperty("mcRuns",
52 "100"));
53
54 /**
55 * Using a Monte-Carlo procedure, this test checks the error estimations
56 * as provided by the square-root of the diagonal elements of the
57 * covariance matrix.
58 * <br>
59 * The test generates sets of observations, each sampled from
60 * a Gaussian distribution.
61 * <br>
62 * The optimization problem solved is defined in class
63 * {@link StraightLineProblem}.
64 * <br>
65 * The output (on stdout) will be a table summarizing the distribution
66 * of parameters generated by the Monte-Carlo process and by the direct
67 * estimation provided by the diagonal elements of the covariance matrix.
68 */
69 @Test
70 public void testParametersErrorMonteCarloObservations() {
71 // Error on the observations.
72 final double yError = 15;
73
74 // True values of the parameters.
75 final double slope = 123.456;
76 final double offset = -98.765;
77
78 // Samples generator.
79 final RandomStraightLinePointGenerator lineGenerator
80 = new RandomStraightLinePointGenerator(slope, offset,
81 yError,
82 -1e3, 1e4,
83 138577L);
84
85 // Number of observations.
86 final int numObs = 100; // XXX Should be a command-line option.
87 // number of parameters.
88 final int numParams = 2;
89
90 // Parameters found for each of Monte-Carlo run.
91 final SummaryStatistics[] paramsFoundByDirectSolution = new SummaryStatistics[numParams];
92 // Sigma estimations (square-root of the diagonal elements of the
93 // covariance matrix), for each Monte-Carlo run.
94 final SummaryStatistics[] sigmaEstimate = new SummaryStatistics[numParams];
95
96 // Initialize statistics accumulators.
97 for (int i = 0; i < numParams; i++) {
98 paramsFoundByDirectSolution[i] = new SummaryStatistics();
99 sigmaEstimate[i] = new SummaryStatistics();
100 }
101
102 final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
103
104 // Monte-Carlo (generates many sets of observations).
105 final int mcRepeat = MONTE_CARLO_RUNS;
106 int mcCount = 0;
107 while (mcCount < mcRepeat) {
108 // Observations.
109 final Point2D.Double[] obs = lineGenerator.generate(numObs);
110
111 final StraightLineProblem problem = new StraightLineProblem(yError);
112 for (int i = 0; i < numObs; i++) {
113 final Point2D.Double p = obs[i];
114 problem.addPoint(p.x, p.y);
115 }
116
117 // Direct solution (using simple regression).
118 final double[] regress = problem.solve();
119
120 // Estimation of the standard deviation (diagonal elements of the
121 // covariance matrix).
122 final LeastSquaresProblem lsp = builder(problem).build();
123
124 final RealVector sigma = lsp.evaluate(init).getSigma(1e-14);
125
126 // Accumulate statistics.
127 for (int i = 0; i < numParams; i++) {
128 paramsFoundByDirectSolution[i].addValue(regress[i]);
129 sigmaEstimate[i].addValue(sigma.getEntry(i));
130 }
131
132 // Next Monte-Carlo.
133 ++mcCount;
134 }
135
136 // Print statistics.
137 final String line = "--------------------------------------------------------------";
138 System.out.println(" True value Mean Std deviation");
139 for (int i = 0; i < numParams; i++) {
140 System.out.println(line);
141 System.out.println("Parameter #" + i);
142
143 StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
144 System.out.printf(" %+.6e %+.6e %+.6e\n",
145 init.getEntry(i),
146 s.getMean(),
147 s.getStandardDeviation());
148
149 s = sigmaEstimate[i].getSummary();
150 System.out.printf("sigma: %+.6e (%+.6e)\n",
151 s.getMean(),
152 s.getStandardDeviation());
153 }
154 System.out.println(line);
155
156 // Check the error estimation.
157 for (int i = 0; i < numParams; i++) {
158 Assert.assertEquals(paramsFoundByDirectSolution[i].getSummary().getStandardDeviation(),
159 sigmaEstimate[i].getSummary().getMean(),
160 8e-2);
161 }
162 }
163
164 /**
165 * In this test, the set of observations is fixed.
166 * Using a Monte-Carlo procedure, it generates sets of parameters,
167 * and determine the parameter change that will result in the
168 * normalized chi-square becoming larger by one than the value from
169 * the best fit solution.
170 * <br>
171 * The optimization problem solved is defined in class
172 * {@link StraightLineProblem}.
173 * <br>
174 * The output (on stdout) will be a list of lines containing:
175 * <ul>
176 * <li>slope of the straight line,</li>
177 * <li>intercept of the straight line,</li>
178 * <li>chi-square of the solution defined by the above two values.</li>
179 * </ul>
180 * The output is separated into two blocks (with a blank line between
181 * them); the first block will contain all parameter sets for which
182 * {@code chi2 < chi2_b + 1}
183 * and the second block, all sets for which
184 * {@code chi2 >= chi2_b + 1}
185 * where {@code chi2_b} is the lowest chi-square (corresponding to the
186 * best solution).
187 */
188 @Test
189 public void testParametersErrorMonteCarloParameters() {
190 // Error on the observations.
191 final double yError = 15;
192
193 // True values of the parameters.
194 final double slope = 123.456;
195 final double offset = -98.765;
196
197 // Samples generator.
198 final RandomStraightLinePointGenerator lineGenerator
199 = new RandomStraightLinePointGenerator(slope, offset,
200 yError,
201 -1e3, 1e4,
202 13839013L);
203
204 // Number of observations.
205 final int numObs = 10;
206 // number of parameters.
207
208 // Create a single set of observations.
209 final Point2D.Double[] obs = lineGenerator.generate(numObs);
210
211 final StraightLineProblem problem = new StraightLineProblem(yError);
212 for (int i = 0; i < numObs; i++) {
213 final Point2D.Double p = obs[i];
214 problem.addPoint(p.x, p.y);
215 }
216
217 // Direct solution (using simple regression).
218 final RealVector regress = new ArrayRealVector(problem.solve(), false);
219
220 // Dummy optimizer (to compute the chi-square).
221 final LeastSquaresProblem lsp = builder(problem).build();
222
223 // Get chi-square of the best parameters set for the given set of
224 // observations.
225 final double bestChi2N = getChi2N(lsp, regress);
226 final RealVector sigma = lsp.evaluate(regress).getSigma(1e-14);
227
228 // Monte-Carlo (generates a grid of parameters).
229 final int mcRepeat = MONTE_CARLO_RUNS;
230 final int gridSize = (int) JdkMath.sqrt(mcRepeat);
231
232 // Parameters found for each of Monte-Carlo run.
233 // Index 0 = slope
234 // Index 1 = offset
235 // Index 2 = normalized chi2
236 final List<double[]> paramsAndChi2 = new ArrayList<>(gridSize * gridSize);
237
238 final double slopeRange = 10 * sigma.getEntry(0);
239 final double offsetRange = 10 * sigma.getEntry(1);
240 final double minSlope = slope - 0.5 * slopeRange;
241 final double minOffset = offset - 0.5 * offsetRange;
242 final double deltaSlope = slopeRange/ gridSize;
243 final double deltaOffset = offsetRange / gridSize;
244 for (int i = 0; i < gridSize; i++) {
245 final double s = minSlope + i * deltaSlope;
246 for (int j = 0; j < gridSize; j++) {
247 final double o = minOffset + j * deltaOffset;
248 final double chi2N = getChi2N(lsp,
249 new ArrayRealVector(new double[] {s, o}, false));
250
251 paramsAndChi2.add(new double[] {s, o, chi2N});
252 }
253 }
254
255 // Output (for use with "gnuplot").
256
257 // Some info.
258
259 // For plotting separately sets of parameters that have a large chi2.
260 final double chi2NPlusOne = bestChi2N + 1;
261 int numLarger = 0;
262
263 final String lineFmt = "%+.10e %+.10e %.8e\n";
264
265 // Point with smallest chi-square.
266 System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
267 System.out.println(); // Empty line.
268
269 // Points within the confidence interval.
270 for (double[] d : paramsAndChi2) {
271 if (d[2] <= chi2NPlusOne) {
272 System.out.printf(lineFmt, d[0], d[1], d[2]);
273 }
274 }
275 System.out.println(); // Empty line.
276
277 // Points outside the confidence interval.
278 for (double[] d : paramsAndChi2) {
279 if (d[2] > chi2NPlusOne) {
280 ++numLarger;
281 System.out.printf(lineFmt, d[0], d[1], d[2]);
282 }
283 }
284 System.out.println(); // Empty line.
285
286 System.out.println("# sigma=" + sigma.toString());
287 System.out.println("# " + numLarger + " sets filtered out");
288 }
289
290 LeastSquaresBuilder builder(StraightLineProblem problem){
291 return new LeastSquaresBuilder()
292 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
293 .target(problem.target())
294 .weight(new DiagonalMatrix(problem.weight()))
295 //unused start point to avoid NPE
296 .start(new double[2]);
297 }
298 /**
299 * @return the normalized chi-square.
300 */
301 private double getChi2N(LeastSquaresProblem lsp,
302 RealVector params) {
303 final double cost = lsp.evaluate(params).getCost();
304 return cost * cost / (lsp.getObservationSize() - params.getDimension());
305 }
306 }
307