View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.fitting.leastsquares;
19  
20  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
21  import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
22  import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
23  import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem.Evaluation;
24  import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
25  import org.apache.commons.math4.legacy.linear.RealMatrix;
26  import org.apache.commons.math4.legacy.linear.RealVector;
27  import org.apache.commons.math4.legacy.linear.SingularMatrixException;
28  import org.apache.commons.math4.legacy.optim.ConvergenceChecker;
29  import org.apache.commons.math4.core.jdkmath.JdkMath;
30  import org.apache.commons.numbers.core.Precision;
31  import org.junit.Assert;
32  import org.junit.Test;
33  
34  /**
35   * <p>Some of the unit tests are re-implementations of the MINPACK <a
36   * href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
37   * href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.
38   * The redistribution policy for MINPACK is available <a
39   * href="http://www.netlib.org/minpack/disclaimer">here</a>.
40   *
41   */
42  public class LevenbergMarquardtOptimizerTest
43      extends AbstractLeastSquaresOptimizerAbstractTest{
44  
45      public LeastSquaresBuilder builder(BevingtonProblem problem){
46          return base()
47                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian());
48      }
49  
50      public LeastSquaresBuilder builder(CircleProblem problem){
51          return base()
52                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
53                  .target(problem.target())
54                  .weight(new DiagonalMatrix(problem.weight()));
55      }
56  
57      @Override
58      public int getMaxIterations() {
59          return 25;
60      }
61  
62      @Override
63      public LeastSquaresOptimizer getOptimizer() {
64          return new LevenbergMarquardtOptimizer();
65      }
66  
67      @Override
68      @Test
69      public void testNonInvertible() {
70          try{
71              /*
72               * Overrides the method from parent class, since the default singularity
73               * threshold (1e-14) does not trigger the expected exception.
74               */
75              LinearProblem problem = new LinearProblem(new double[][] {
76                      {  1, 2, -3 },
77                      {  2, 1,  3 },
78                      { -3, 0, -9 }
79              }, new double[] { 1, 1, 1 });
80  
81              final Optimum optimum = optimizer.optimize(
82                      problem.getBuilder().maxIterations(20).build());
83  
84              //TODO check that it is a bad fit? Why the extra conditions?
85              Assert.assertTrue(JdkMath.sqrt(problem.getTarget().length) * optimum.getRMS() > 0.6);
86  
87              optimum.getCovariances(1.5e-14);
88  
89              fail(optimizer);
90          }catch (SingularMatrixException e){
91              //expected
92          }
93      }
94  
95      @Test
96      public void testControlParameters() {
97          CircleVectorial circle = new CircleVectorial();
98          circle.addPoint( 30.0,  68.0);
99          circle.addPoint( 50.0,  -6.0);
100         circle.addPoint(110.0, -20.0);
101         circle.addPoint( 35.0,  15.0);
102         circle.addPoint( 45.0,  97.0);
103         checkEstimate(
104                 circle, 0.1, 10, 1.0e-14, 1.0e-16, 1.0e-10, false);
105         checkEstimate(
106                 circle, 0.1, 10, 1.0e-15, 1.0e-17, 1.0e-10, false);
107         checkEstimate(
108                 circle, 0.1,  5, 1.0e-15, 1.0e-16, 1.0e-10, true);
109         circle.addPoint(300, -300);
110         //wardev I changed true => false
111         //TODO why should this fail? It uses 15 evaluations.
112         checkEstimate(
113                 circle, 0.1, 20, 1.0e-18, 1.0e-16, 1.0e-10, false);
114     }
115 
116     private void checkEstimate(CircleVectorial circle,
117                                double initialStepBoundFactor, int maxCostEval,
118                                double costRelativeTolerance, double parRelativeTolerance,
119                                double orthoTolerance, boolean shouldFail) {
120         try {
121             final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer()
122                 .withInitialStepBoundFactor(initialStepBoundFactor)
123                 .withCostRelativeTolerance(costRelativeTolerance)
124                 .withParameterRelativeTolerance(parRelativeTolerance)
125                 .withOrthoTolerance(orthoTolerance)
126                 .withRankingThreshold(Precision.SAFE_MIN);
127 
128             final LeastSquaresProblem problem = builder(circle)
129                     .maxEvaluations(maxCostEval)
130                     .maxIterations(100)
131                     .start(new double[] { 98.680, 47.345 })
132                     .build();
133 
134             optimizer.optimize(problem);
135 
136             Assert.assertTrue(!shouldFail);
137             //TODO check it got the right answer
138         } catch (DimensionMismatchException ee) {
139             Assert.assertTrue(shouldFail);
140         } catch (TooManyEvaluationsException ee) {
141             Assert.assertTrue(shouldFail);
142         }
143     }
144 
145     /**
146      * Non-linear test case: fitting of decay curve (from Chapter 8 of
147      * Bevington's textbook, "Data reduction and analysis for the physical sciences").
148      * XXX The expected ("reference") values may not be accurate and the tolerance too
149      * relaxed for this test to be currently really useful (the issue is under
150      * investigation).
151      */
152     @Test
153     public void testBevington() {
154         final double[][] dataPoints = {
155             // column 1 = times
156             { 15, 30, 45, 60, 75, 90, 105, 120, 135, 150,
157               165, 180, 195, 210, 225, 240, 255, 270, 285, 300,
158               315, 330, 345, 360, 375, 390, 405, 420, 435, 450,
159               465, 480, 495, 510, 525, 540, 555, 570, 585, 600,
160               615, 630, 645, 660, 675, 690, 705, 720, 735, 750,
161               765, 780, 795, 810, 825, 840, 855, 870, 885, },
162             // column 2 = measured counts
163             { 775, 479, 380, 302, 185, 157, 137, 119, 110, 89,
164               74, 61, 66, 68, 48, 54, 51, 46, 55, 29,
165               28, 37, 49, 26, 35, 29, 31, 24, 25, 35,
166               24, 30, 26, 28, 21, 18, 20, 27, 17, 17,
167               14, 17, 24, 11, 22, 17, 12, 10, 13, 16,
168               9, 9, 14, 21, 17, 13, 12, 18, 10, },
169         };
170         final double[] start = {10, 900, 80, 27, 225};
171 
172         final BevingtonProblem problem = new BevingtonProblem();
173 
174         final int len = dataPoints[0].length;
175         final double[] weights = new double[len];
176         for (int i = 0; i < len; i++) {
177             problem.addPoint(dataPoints[0][i],
178                              dataPoints[1][i]);
179 
180             weights[i] = 1 / dataPoints[1][i];
181         }
182 
183         final Optimum optimum = optimizer.optimize(
184                 builder(problem)
185                         .target(dataPoints[1])
186                         .weight(new DiagonalMatrix(weights))
187                         .start(start)
188                         .maxIterations(20)
189                         .build()
190         );
191 
192         final RealVector solution = optimum.getPoint();
193         final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
194 
195         final RealMatrix covarMatrix = optimum.getCovariances(1e-14);
196         final double[][] expectedCovarMatrix = {
197             { 3.38, -3.69, 27.98, -2.34, -49.24 },
198             { -3.69, 2492.26, 81.89, -69.21, -8.9 },
199             { 27.98, 81.89, 468.99, -44.22, -615.44 },
200             { -2.34, -69.21, -44.22, 6.39, 53.80 },
201             { -49.24, -8.9, -615.44, 53.8, 929.45 }
202         };
203 
204         final int numParams = expectedSolution.length;
205 
206         // Check that the computed solution is within the reference error range.
207         for (int i = 0; i < numParams; i++) {
208             final double error = JdkMath.sqrt(expectedCovarMatrix[i][i]);
209             Assert.assertEquals("Parameter " + i, expectedSolution[i], solution.getEntry(i), error);
210         }
211 
212         // Check that each entry of the computed covariance matrix is within 10%
213         // of the reference matrix entry.
214         for (int i = 0; i < numParams; i++) {
215             for (int j = 0; j < numParams; j++) {
216                 Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]",
217                                     expectedCovarMatrix[i][j],
218                                     covarMatrix.getEntry(i, j),
219                                     JdkMath.abs(0.1 * expectedCovarMatrix[i][j]));
220             }
221         }
222 
223         // Check various measures of goodness-of-fit.
224         final double chi2 = optimum.getChiSquare();
225         final double cost = optimum.getCost();
226         final double rms = optimum.getRMS();
227         final double reducedChi2 = optimum.getReducedChiSquare(start.length);
228 
229         // XXX Values computed by the CM code: It would be better to compare
230         // with the results from another library.
231         final double expectedChi2 = 66.07852350839286;
232         final double expectedReducedChi2 = 1.2014277001525975;
233         final double expectedCost = 8.128869755900439;
234         final double expectedRms = 1.0582887010256337;
235 
236         final double tol = 1e-14;
237         Assert.assertEquals(expectedChi2, chi2, tol);
238         Assert.assertEquals(expectedReducedChi2, reducedChi2, tol);
239         Assert.assertEquals(expectedCost, cost, tol);
240         Assert.assertEquals(expectedRms, rms, tol);
241     }
242 
243     @Test
244     public void testCircleFitting2() {
245         final double xCenter = 123.456;
246         final double yCenter = 654.321;
247         final double xSigma = 10;
248         final double ySigma = 15;
249         final double radius = 111.111;
250         // The test is extremely sensitive to the seed.
251         final RandomCirclePointGenerator factory
252             = new RandomCirclePointGenerator(xCenter, yCenter, radius,
253                                              xSigma, ySigma);
254         final CircleProblem circle = new CircleProblem(xSigma, ySigma);
255 
256         final int numPoints = 10;
257         factory.samples(numPoints).forEach(circle::addPoint);
258 
259         // First guess for the center's coordinates and radius.
260         final double[] init = { 118, 659, 115 };
261 
262         final Optimum optimum = optimizer.optimize(
263                 builder(circle).maxIterations(50).start(init).build());
264 
265         final double[] paramFound = optimum.getPoint().toArray();
266 
267         // Retrieve errors estimation.
268         final double[] asymptoticStandardErrorFound = optimum.getSigma(1e-14).toArray();
269 
270         // Check that the parameters are found within the assumed error bars.
271         Assert.assertEquals("Delta=" + 2 * asymptoticStandardErrorFound[0], xCenter, paramFound[0], 2 * asymptoticStandardErrorFound[0]);
272         Assert.assertEquals("Delta=" + 2 * asymptoticStandardErrorFound[1], yCenter, paramFound[1], 2 * asymptoticStandardErrorFound[1]);
273         Assert.assertEquals("Delta=" + 2 * asymptoticStandardErrorFound[2], radius, paramFound[2], asymptoticStandardErrorFound[2]);
274     }
275 
276     @Test
277     public void testParameterValidator() {
278         // Setup.
279         final double xCenter = 123.456;
280         final double yCenter = 654.321;
281         final double xSigma = 10;
282         final double ySigma = 15;
283         final double radius = 111.111;
284         final RandomCirclePointGenerator factory
285             = new RandomCirclePointGenerator(xCenter, yCenter, radius,
286                                              xSigma, ySigma);
287         final CircleProblem circle = new CircleProblem(xSigma, ySigma);
288 
289         final int numPoints = 10;
290         factory.samples(numPoints).forEach(circle::addPoint);
291 
292         // First guess for the center's coordinates and radius.
293         final double[] init = { 118, 659, 115 };
294 
295         final Optimum optimum = optimizer.optimize(
296                 builder(circle).maxIterations(50).start(init).build());
297 
298         final int numEval = optimum.getEvaluations();
299         Assert.assertTrue(numEval > 1);
300 
301         // Build a new problem with a validator that amounts to cheating.
302 
303         // Note we cannot return a fixed point.
304         // The optimiser relies on computing a predicted reduction in the cost
305         // function (preRed) and an actual reduction (actRed). The ratio between them must be
306         // non-zero to indicate the step reduced the cost function. If a threshold is not
307         // achieved then the step is rejected and the optimiser can cycle through many iterations
308         // not moving anywhere until alternative thresholds reduce to a level that terminate
309         // the cycle.
310         // Here we take the current point and move it towards an acceptable answer
311         // given the problem (the previous optimum). This should speed up the optimiser.
312         // This can still fail to reduce the iterations when the adjusted step moves
313         // to a sub-optimal position in the cost function.
314         final ParameterValidator cheatValidator
315             = new ParameterValidator() {
316                     @Override
317                     public RealVector validate(RealVector params) {
318                         // Cheat: Move towards the optimum found previously.
319                         final RealVector direction = optimum.getPoint().subtract(params);
320                         return params.add(direction.mapMultiply(0.75));
321                     }
322                 };
323 
324         final Optimum cheatOptimum
325             = optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build());
326         final int cheatNumEval = cheatOptimum.getEvaluations();
327         Assert.assertTrue("n=" + numEval + " nc=" + cheatNumEval, cheatNumEval < numEval);
328         // System.out.println("n=" + numEval + " nc=" + cheatNumEval);
329     }
330 
331     @Test
332     public void testEvaluationCount() {
333         //setup
334         LeastSquaresProblem lsp = new LinearProblem(new double[][] {{1}}, new double[] {1})
335                 .getBuilder()
336                 .checker(new ConvergenceChecker<Evaluation>() {
337                     @Override
338                     public boolean converged(int iteration, Evaluation previous, Evaluation current) {
339                         return true;
340                     }
341                 })
342                 .build();
343 
344         //action
345         Optimum optimum = optimizer.optimize(lsp);
346 
347         //verify
348         //check iterations and evaluations are not switched.
349         Assert.assertEquals(1, optimum.getIterations());
350         Assert.assertEquals(2, optimum.getEvaluations());
351     }
352 }