View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting.leastsquares;
18  
19  import org.apache.commons.math4.legacy.linear.ArrayRealVector;
20  import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
21  import org.apache.commons.math4.legacy.linear.RealVector;
22  import org.apache.commons.math4.legacy.stat.descriptive.StatisticalSummary;
23  import org.apache.commons.math4.legacy.stat.descriptive.SummaryStatistics;
24  import org.apache.commons.math4.core.jdkmath.JdkMath;
25  import org.junit.Assert;
26  import org.junit.Test;
27  
28  import java.awt.geom.Point2D;
29  import java.util.ArrayList;
30  import java.util.List;
31  
32  /**
33   * This class demonstrates the main functionality of the
34   * {@link LeastSquaresProblem.Evaluation}, common to the
35   * optimizer implementations in package
36   * {@link org.apache.commons.math4.legacy.fitting.leastsquares}.
37   * <br>
38   * Not enabled by default, as the class name does not end with "Test".
39   * <br>
40   * Invoke by running
41   * <pre><code>
42   *  mvn test -Dtest=EvaluationTestValidation
43   * </code></pre>
44   * or by running
45   * <pre><code>
46   *  mvn test -Dtest=EvaluationTestValidation -DargLine="-DmcRuns=1234 -server"
47   * </code></pre>
48   */
49  public class EvaluationTestValidation {
50      /** Number of runs. */
51      private static final int MONTE_CARLO_RUNS = Integer.parseInt(System.getProperty("mcRuns",
52                                                                                      "100"));
53  
54      /**
55       * Using a Monte-Carlo procedure, this test checks the error estimations
56       * as provided by the square-root of the diagonal elements of the
57       * covariance matrix.
58       * <br>
59       * The test generates sets of observations, each sampled from
60       * a Gaussian distribution.
61       * <br>
62       * The optimization problem solved is defined in class
63       * {@link StraightLineProblem}.
64       * <br>
65       * The output (on stdout) will be a table summarizing the distribution
66       * of parameters generated by the Monte-Carlo process and by the direct
67       * estimation provided by the diagonal elements of the covariance matrix.
68       */
69      @Test
70      public void testParametersErrorMonteCarloObservations() {
71          // Error on the observations.
72          final double yError = 15;
73  
74          // True values of the parameters.
75          final double slope = 123.456;
76          final double offset = -98.765;
77  
78          // Samples generator.
79          final RandomStraightLinePointGenerator lineGenerator
80              = new RandomStraightLinePointGenerator(slope, offset,
81                                                     yError,
82                                                     -1e3, 1e4,
83                                                     138577L);
84  
85          // Number of observations.
86          final int numObs = 100; // XXX Should be a command-line option.
87          // number of parameters.
88          final int numParams = 2;
89  
90          // Parameters found for each of Monte-Carlo run.
91          final SummaryStatistics[] paramsFoundByDirectSolution = new SummaryStatistics[numParams];
92          // Sigma estimations (square-root of the diagonal elements of the
93          // covariance matrix), for each Monte-Carlo run.
94          final SummaryStatistics[] sigmaEstimate = new SummaryStatistics[numParams];
95  
96          // Initialize statistics accumulators.
97          for (int i = 0; i < numParams; i++) {
98              paramsFoundByDirectSolution[i] = new SummaryStatistics();
99              sigmaEstimate[i] = new SummaryStatistics();
100         }
101 
102         final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
103 
104         // Monte-Carlo (generates many sets of observations).
105         final int mcRepeat = MONTE_CARLO_RUNS;
106         int mcCount = 0;
107         while (mcCount < mcRepeat) {
108             // Observations.
109             final Point2D.Double[] obs = lineGenerator.generate(numObs);
110 
111             final StraightLineProblem problem = new StraightLineProblem(yError);
112             for (int i = 0; i < numObs; i++) {
113                 final Point2D.Double p = obs[i];
114                 problem.addPoint(p.x, p.y);
115             }
116 
117             // Direct solution (using simple regression).
118             final double[] regress = problem.solve();
119 
120             // Estimation of the standard deviation (diagonal elements of the
121             // covariance matrix).
122             final LeastSquaresProblem lsp = builder(problem).build();
123 
124             final RealVector sigma = lsp.evaluate(init).getSigma(1e-14);
125 
126             // Accumulate statistics.
127             for (int i = 0; i < numParams; i++) {
128                 paramsFoundByDirectSolution[i].addValue(regress[i]);
129                 sigmaEstimate[i].addValue(sigma.getEntry(i));
130             }
131 
132             // Next Monte-Carlo.
133             ++mcCount;
134         }
135 
136         // Print statistics.
137         final String line = "--------------------------------------------------------------";
138         System.out.println("                 True value       Mean        Std deviation");
139         for (int i = 0; i < numParams; i++) {
140             System.out.println(line);
141             System.out.println("Parameter #" + i);
142 
143             StatisticalSummary s = paramsFoundByDirectSolution[i].getSummary();
144             System.out.printf("              %+.6e   %+.6e   %+.6e\n",
145                               init.getEntry(i),
146                               s.getMean(),
147                               s.getStandardDeviation());
148 
149             s = sigmaEstimate[i].getSummary();
150             System.out.printf("sigma: %+.6e (%+.6e)\n",
151                               s.getMean(),
152                               s.getStandardDeviation());
153         }
154         System.out.println(line);
155 
156         // Check the error estimation.
157         for (int i = 0; i < numParams; i++) {
158             Assert.assertEquals(paramsFoundByDirectSolution[i].getSummary().getStandardDeviation(),
159                                 sigmaEstimate[i].getSummary().getMean(),
160                                 8e-2);
161         }
162     }
163 
164     /**
165      * In this test, the set of observations is fixed.
166      * Using a Monte-Carlo procedure, it generates sets of parameters,
167      * and determine the parameter change that will result in the
168      * normalized chi-square becoming larger by one than the value from
169      * the best fit solution.
170      * <br>
171      * The optimization problem solved is defined in class
172      * {@link StraightLineProblem}.
173      * <br>
174      * The output (on stdout) will be a list of lines containing:
175      * <ul>
176      *  <li>slope of the straight line,</li>
177      *  <li>intercept of the straight line,</li>
178      *  <li>chi-square of the solution defined by the above two values.</li>
179      * </ul>
180      * The output is separated into two blocks (with a blank line between
181      * them); the first block will contain all parameter sets for which
182      * {@code chi2 < chi2_b + 1}
183      * and the second block, all sets for which
184      * {@code chi2 >= chi2_b + 1}
185      * where {@code chi2_b} is the lowest chi-square (corresponding to the
186      * best solution).
187      */
188     @Test
189     public void testParametersErrorMonteCarloParameters() {
190         // Error on the observations.
191         final double yError = 15;
192 
193         // True values of the parameters.
194         final double slope = 123.456;
195         final double offset = -98.765;
196 
197         // Samples generator.
198         final RandomStraightLinePointGenerator lineGenerator
199             = new RandomStraightLinePointGenerator(slope, offset,
200                                                    yError,
201                                                    -1e3, 1e4,
202                                                    13839013L);
203 
204         // Number of observations.
205         final int numObs = 10;
206         // number of parameters.
207 
208         // Create a single set of observations.
209         final Point2D.Double[] obs = lineGenerator.generate(numObs);
210 
211         final StraightLineProblem problem = new StraightLineProblem(yError);
212         for (int i = 0; i < numObs; i++) {
213             final Point2D.Double p = obs[i];
214             problem.addPoint(p.x, p.y);
215         }
216 
217         // Direct solution (using simple regression).
218         final RealVector regress = new ArrayRealVector(problem.solve(), false);
219 
220         // Dummy optimizer (to compute the chi-square).
221         final LeastSquaresProblem lsp = builder(problem).build();
222 
223         // Get chi-square of the best parameters set for the given set of
224         // observations.
225         final double bestChi2N = getChi2N(lsp, regress);
226         final RealVector sigma = lsp.evaluate(regress).getSigma(1e-14);
227 
228         // Monte-Carlo (generates a grid of parameters).
229         final int mcRepeat = MONTE_CARLO_RUNS;
230         final int gridSize = (int) JdkMath.sqrt(mcRepeat);
231 
232         // Parameters found for each of Monte-Carlo run.
233         // Index 0 = slope
234         // Index 1 = offset
235         // Index 2 = normalized chi2
236         final List<double[]> paramsAndChi2 = new ArrayList<>(gridSize * gridSize);
237 
238         final double slopeRange = 10 * sigma.getEntry(0);
239         final double offsetRange = 10 * sigma.getEntry(1);
240         final double minSlope = slope - 0.5 * slopeRange;
241         final double minOffset = offset - 0.5 * offsetRange;
242         final double deltaSlope =  slopeRange/ gridSize;
243         final double deltaOffset = offsetRange / gridSize;
244         for (int i = 0; i < gridSize; i++) {
245             final double s = minSlope + i * deltaSlope;
246             for (int j = 0; j < gridSize; j++) {
247                 final double o = minOffset + j * deltaOffset;
248                 final double chi2N = getChi2N(lsp,
249                         new ArrayRealVector(new double[] {s, o}, false));
250 
251                 paramsAndChi2.add(new double[] {s, o, chi2N});
252             }
253         }
254 
255         // Output (for use with "gnuplot").
256 
257         // Some info.
258 
259         // For plotting separately sets of parameters that have a large chi2.
260         final double chi2NPlusOne = bestChi2N + 1;
261         int numLarger = 0;
262 
263         final String lineFmt = "%+.10e %+.10e   %.8e\n";
264 
265         // Point with smallest chi-square.
266         System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
267         System.out.println(); // Empty line.
268 
269         // Points within the confidence interval.
270         for (double[] d : paramsAndChi2) {
271             if (d[2] <= chi2NPlusOne) {
272                 System.out.printf(lineFmt, d[0], d[1], d[2]);
273             }
274         }
275         System.out.println(); // Empty line.
276 
277         // Points outside the confidence interval.
278         for (double[] d : paramsAndChi2) {
279             if (d[2] > chi2NPlusOne) {
280                 ++numLarger;
281                 System.out.printf(lineFmt, d[0], d[1], d[2]);
282             }
283         }
284         System.out.println(); // Empty line.
285 
286         System.out.println("# sigma=" + sigma.toString());
287         System.out.println("# " + numLarger + " sets filtered out");
288     }
289 
290     LeastSquaresBuilder builder(StraightLineProblem problem){
291         return new LeastSquaresBuilder()
292                 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
293                 .target(problem.target())
294                 .weight(new DiagonalMatrix(problem.weight()))
295                 //unused start point to avoid NPE
296                 .start(new double[2]);
297     }
298     /**
299      * @return the normalized chi-square.
300      */
301     private double getChi2N(LeastSquaresProblem lsp,
302                             RealVector params) {
303         final double cost = lsp.evaluate(params).getCost();
304         return cost * cost / (lsp.getObservationSize() - params.getDimension());
305     }
306 }
307