1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.fitting;
18
19 import java.util.Random;
20
21 import org.apache.commons.math4.legacy.TestUtils;
22 import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialFunction;
23 import org.apache.commons.statistics.distribution.ContinuousDistribution;
24 import org.apache.commons.statistics.distribution.UniformContinuousDistribution;
25 import org.apache.commons.math4.legacy.exception.ConvergenceException;
26 import org.apache.commons.math4.core.jdkmath.JdkMath;
27 import org.apache.commons.rng.simple.RandomSource;
28 import org.junit.Assert;
29 import org.junit.Test;
30
31
32
33
34 public class PolynomialCurveFitterTest {
35 @Test
36 public void testFit() {
37 final ContinuousDistribution.Sampler rng
38 = UniformContinuousDistribution.of(-100, 100).createSampler(RandomSource.WELL_512_A.create(64925784252L));
39 final double[] coeff = { 12.9, -3.4, 2.1 };
40 final PolynomialFunction f = new PolynomialFunction(coeff);
41
42
43 final WeightedObservedPoints obs = new WeightedObservedPoints();
44 for (int i = 0; i < 100; i++) {
45 final double x = rng.sample();
46 obs.add(x, f.value(x));
47 }
48
49
50 final SimpleCurveFitter fitter
51 = PolynomialCurveFitter.create(0).withStartPoint(new double[] { -1e-20, 3e15, -5e25 });
52 final double[] best = fitter.fit(obs.toList());
53
54 TestUtils.assertEquals("best != coeff", coeff, best, 1e-12);
55 }
56
57 @Test
58 public void testNoError() {
59 final Random randomizer = new Random(64925784252L);
60 for (int degree = 1; degree < 10; ++degree) {
61 final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
62 final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
63
64 final WeightedObservedPoints obs = new WeightedObservedPoints();
65 for (int i = 0; i <= degree; ++i) {
66 obs.add(1.0, i, p.value(i));
67 }
68
69 final PolynomialFunction fitted = new PolynomialFunction(fitter.fit(obs.toList()));
70
71 for (double x = -1.0; x < 1.0; x += 0.01) {
72 final double error = JdkMath.abs(p.value(x) - fitted.value(x)) /
73 (1.0 + JdkMath.abs(p.value(x)));
74 Assert.assertEquals(0.0, error, 1.0e-6);
75 }
76 }
77 }
78
79 @Test
80 public void testSmallError() {
81 final Random randomizer = new Random(53882150042L);
82 double maxError = 0;
83 for (int degree = 0; degree < 10; ++degree) {
84 final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
85 final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
86
87 final WeightedObservedPoints obs = new WeightedObservedPoints();
88 for (double x = -1.0; x < 1.0; x += 0.01) {
89 obs.add(1.0, x, p.value(x) + 0.1 * randomizer.nextGaussian());
90 }
91
92 final PolynomialFunction fitted = new PolynomialFunction(fitter.fit(obs.toList()));
93
94 for (double x = -1.0; x < 1.0; x += 0.01) {
95 final double error = JdkMath.abs(p.value(x) - fitted.value(x)) /
96 (1.0 + JdkMath.abs(p.value(x)));
97 maxError = JdkMath.max(maxError, error);
98 Assert.assertTrue(JdkMath.abs(error) < 0.1);
99 }
100 }
101 Assert.assertTrue(maxError > 0.01);
102 }
103
104 @Test
105 public void testRedundantSolvable() {
106
107 checkUnsolvableProblem(true);
108 }
109
110 @Test
111 public void testLargeSample() {
112 final Random randomizer = new Random(0x5551480dca5b369bL);
113 double maxError = 0;
114 for (int degree = 0; degree < 10; ++degree) {
115 final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
116 final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
117
118 final WeightedObservedPoints obs = new WeightedObservedPoints();
119 for (int i = 0; i < 40000; ++i) {
120 final double x = -1.0 + i / 20000.0;
121 obs.add(1.0, x, p.value(x) + 0.1 * randomizer.nextGaussian());
122 }
123
124 final PolynomialFunction fitted = new PolynomialFunction(fitter.fit(obs.toList()));
125 for (double x = -1.0; x < 1.0; x += 0.01) {
126 final double error = JdkMath.abs(p.value(x) - fitted.value(x)) /
127 (1.0 + JdkMath.abs(p.value(x)));
128 maxError = JdkMath.max(maxError, error);
129 Assert.assertTrue(JdkMath.abs(error) < 0.01);
130 }
131 }
132 Assert.assertTrue(maxError > 0.001);
133 }
134
135 private void checkUnsolvableProblem(boolean solvable) {
136 final Random randomizer = new Random(1248788532L);
137
138 for (int degree = 0; degree < 10; ++degree) {
139 final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
140 final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
141 final WeightedObservedPoints obs = new WeightedObservedPoints();
142
143
144
145
146 for (double x = -1.0; x < 1.0; x += 0.01) {
147 obs.add(1.0, 0.0, p.value(0.0));
148 }
149
150 try {
151 fitter.fit(obs.toList());
152 Assert.assertTrue(solvable || degree == 0);
153 } catch(ConvergenceException e) {
154 Assert.assertTrue(!solvable && degree > 0);
155 }
156 }
157 }
158
159 private PolynomialFunction buildRandomPolynomial(int degree, Random randomizer) {
160 final double[] coefficients = new double[degree + 1];
161 for (int i = 0; i <= degree; ++i) {
162 coefficients[i] = randomizer.nextGaussian();
163 }
164 return new PolynomialFunction(coefficients);
165 }
166 }