View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.fitting;
18  
19  import java.util.Random;
20  
21  import org.apache.commons.math4.legacy.TestUtils;
22  import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialFunction;
23  import org.apache.commons.statistics.distribution.ContinuousDistribution;
24  import org.apache.commons.statistics.distribution.UniformContinuousDistribution;
25  import org.apache.commons.math4.legacy.exception.ConvergenceException;
26  import org.apache.commons.math4.core.jdkmath.JdkMath;
27  import org.apache.commons.rng.simple.RandomSource;
28  import org.junit.Assert;
29  import org.junit.Test;
30  
31  /**
32   * Test for class {@link PolynomialCurveFitter}.
33   */
34  public class PolynomialCurveFitterTest {
35      @Test
36      public void testFit() {
37          final ContinuousDistribution.Sampler rng
38              = UniformContinuousDistribution.of(-100, 100).createSampler(RandomSource.WELL_512_A.create(64925784252L));
39          final double[] coeff = { 12.9, -3.4, 2.1 }; // 12.9 - 3.4 x + 2.1 x^2
40          final PolynomialFunction f = new PolynomialFunction(coeff);
41  
42          // Collect data from a known polynomial.
43          final WeightedObservedPoints obs = new WeightedObservedPoints();
44          for (int i = 0; i < 100; i++) {
45              final double x = rng.sample();
46              obs.add(x, f.value(x));
47          }
48  
49          // Start fit from initial guesses that are far from the optimal values.
50          final SimpleCurveFitter fitter
51              = PolynomialCurveFitter.create(0).withStartPoint(new double[] { -1e-20, 3e15, -5e25 });
52          final double[] best = fitter.fit(obs.toList());
53  
54          TestUtils.assertEquals("best != coeff", coeff, best, 1e-12);
55      }
56  
57      @Test
58      public void testNoError() {
59          final Random randomizer = new Random(64925784252L);
60          for (int degree = 1; degree < 10; ++degree) {
61              final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
62              final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
63  
64              final WeightedObservedPoints obs = new WeightedObservedPoints();
65              for (int i = 0; i <= degree; ++i) {
66                  obs.add(1.0, i, p.value(i));
67              }
68  
69              final PolynomialFunction fitted = new PolynomialFunction(fitter.fit(obs.toList()));
70  
71              for (double x = -1.0; x < 1.0; x += 0.01) {
72                  final double error = JdkMath.abs(p.value(x) - fitted.value(x)) /
73                      (1.0 + JdkMath.abs(p.value(x)));
74                  Assert.assertEquals(0.0, error, 1.0e-6);
75              }
76          }
77      }
78  
79      @Test
80      public void testSmallError() {
81          final Random randomizer = new Random(53882150042L);
82          double maxError = 0;
83          for (int degree = 0; degree < 10; ++degree) {
84              final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
85              final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
86  
87              final WeightedObservedPoints obs = new WeightedObservedPoints();
88              for (double x = -1.0; x < 1.0; x += 0.01) {
89                  obs.add(1.0, x, p.value(x) + 0.1 * randomizer.nextGaussian());
90              }
91  
92              final PolynomialFunction fitted = new PolynomialFunction(fitter.fit(obs.toList()));
93  
94              for (double x = -1.0; x < 1.0; x += 0.01) {
95                  final double error = JdkMath.abs(p.value(x) - fitted.value(x)) /
96                      (1.0 + JdkMath.abs(p.value(x)));
97                  maxError = JdkMath.max(maxError, error);
98                  Assert.assertTrue(JdkMath.abs(error) < 0.1);
99              }
100         }
101         Assert.assertTrue(maxError > 0.01);
102     }
103 
104     @Test
105     public void testRedundantSolvable() {
106         // Levenberg-Marquardt should handle redundant information gracefully
107         checkUnsolvableProblem(true);
108     }
109 
110     @Test
111     public void testLargeSample() {
112         final Random randomizer = new Random(0x5551480dca5b369bL);
113         double maxError = 0;
114         for (int degree = 0; degree < 10; ++degree) {
115             final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
116             final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
117 
118             final WeightedObservedPoints obs = new WeightedObservedPoints();
119             for (int i = 0; i < 40000; ++i) {
120                 final double x = -1.0 + i / 20000.0;
121                 obs.add(1.0, x, p.value(x) + 0.1 * randomizer.nextGaussian());
122             }
123 
124             final PolynomialFunction fitted = new PolynomialFunction(fitter.fit(obs.toList()));
125             for (double x = -1.0; x < 1.0; x += 0.01) {
126                 final double error = JdkMath.abs(p.value(x) - fitted.value(x)) /
127                     (1.0 + JdkMath.abs(p.value(x)));
128                 maxError = JdkMath.max(maxError, error);
129                 Assert.assertTrue(JdkMath.abs(error) < 0.01);
130             }
131         }
132         Assert.assertTrue(maxError > 0.001);
133     }
134 
135     private void checkUnsolvableProblem(boolean solvable) {
136         final Random randomizer = new Random(1248788532L);
137 
138         for (int degree = 0; degree < 10; ++degree) {
139             final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
140             final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
141             final WeightedObservedPoints obs = new WeightedObservedPoints();
142             // reusing the same point over and over again does not bring
143             // information, the problem cannot be solved in this case for
144             // degrees greater than 1 (but one point is sufficient for
145             // degree 0)
146             for (double x = -1.0; x < 1.0; x += 0.01) {
147                 obs.add(1.0, 0.0, p.value(0.0));
148             }
149 
150             try {
151                 fitter.fit(obs.toList());
152                 Assert.assertTrue(solvable || degree == 0);
153             } catch(ConvergenceException e) {
154                 Assert.assertTrue(!solvable && degree > 0);
155             }
156         }
157     }
158 
159     private PolynomialFunction buildRandomPolynomial(int degree, Random randomizer) {
160         final double[] coefficients = new double[degree + 1];
161         for (int i = 0; i <= degree; ++i) {
162             coefficients[i] = randomizer.nextGaussian();
163         }
164         return new PolynomialFunction(coefficients);
165     }
166 }