1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.analysis.function;
19
20 import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
21 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
24 import org.apache.commons.math4.legacy.exception.NullArgumentException;
25 import org.apache.commons.math4.core.jdkmath.JdkMath;
26 import org.junit.Assert;
27 import org.junit.Test;
28
29
30
31
32 public class LogisticTest {
33 private final double EPS = Math.ulp(1d);
34
35 @Test(expected=NotStrictlyPositiveException.class)
36 public void testPreconditions1() {
37 new Logistic(1, 0, 1, 1, 0, -1);
38 }
39
40 @Test(expected=NotStrictlyPositiveException.class)
41 public void testPreconditions2() {
42 new Logistic(1, 0, 1, 1, 0, 0);
43 }
44
45 @Test
46 public void testCompareSigmoid() {
47 final UnivariateFunction sig = new Sigmoid();
48 final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
49
50 final double min = -2;
51 final double max = 2;
52 final int n = 100;
53 final double delta = (max - min) / n;
54 for (int i = 0; i < n; i++) {
55 final double x = min + i * delta;
56 Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
57 }
58 }
59
60 @Test
61 public void testSomeValues() {
62 final double k = 4;
63 final double m = 5;
64 final double b = 2;
65 final double q = 3;
66 final double a = -1;
67 final double n = 2;
68
69 final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
70
71 double x;
72 x = m;
73 Assert.assertEquals("x=" + x, a + (k - a) / JdkMath.sqrt(1 + q), f.value(x), EPS);
74
75 x = Double.NEGATIVE_INFINITY;
76 Assert.assertEquals("x=" + x, a, f.value(x), EPS);
77
78 x = Double.POSITIVE_INFINITY;
79 Assert.assertEquals("x=" + x, k, f.value(x), EPS);
80 }
81
82 @Test
83 public void testCompareDerivativeSigmoid() {
84 final double k = 3;
85 final double a = 2;
86
87 final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
88 final Sigmoid g = new Sigmoid(a, k);
89
90 final double min = -10;
91 final double max = 10;
92 final double n = 20;
93 final double delta = (max - min) / n;
94 for (int i = 0; i < n; i++) {
95 final DerivativeStructure x = new DerivativeStructure(1, 5, 0, min + i * delta);
96 for (int order = 0; order <= x.getOrder(); ++order) {
97 Assert.assertEquals("x=" + x.getValue(),
98 g.value(x).getPartialDerivative(order),
99 f.value(x).getPartialDerivative(order),
100 3.0e-15);
101 }
102 }
103 }
104
105 @Test(expected=NullArgumentException.class)
106 public void testParametricUsage1() {
107 final Logistic.Parametric g = new Logistic.Parametric();
108 g.value(0, null);
109 }
110
111 @Test(expected=DimensionMismatchException.class)
112 public void testParametricUsage2() {
113 final Logistic.Parametric g = new Logistic.Parametric();
114 g.value(0, new double[] {0});
115 }
116
117 @Test(expected=NullArgumentException.class)
118 public void testParametricUsage3() {
119 final Logistic.Parametric g = new Logistic.Parametric();
120 g.gradient(0, null);
121 }
122
123 @Test(expected=DimensionMismatchException.class)
124 public void testParametricUsage4() {
125 final Logistic.Parametric g = new Logistic.Parametric();
126 g.gradient(0, new double[] {0});
127 }
128
129 @Test(expected=NotStrictlyPositiveException.class)
130 public void testParametricUsage5() {
131 final Logistic.Parametric g = new Logistic.Parametric();
132 g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
133 }
134
135 @Test(expected=NotStrictlyPositiveException.class)
136 public void testParametricUsage6() {
137 final Logistic.Parametric g = new Logistic.Parametric();
138 g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
139 }
140
141 @Test
142 public void testGradientComponent0Component4() {
143 final double k = 3;
144 final double a = 2;
145
146 final Logistic.Parametric f = new Logistic.Parametric();
147
148 final Sigmoid.Parametric g = new Sigmoid.Parametric();
149
150 final double x = 0.12345;
151 final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
152 final double[] gg = g.gradient(x, new double[] {a, k});
153
154 Assert.assertEquals(gg[0], gf[4], EPS);
155 Assert.assertEquals(gg[1], gf[0], EPS);
156 }
157
158 @Test
159 public void testGradientComponent5() {
160 final double m = 1.2;
161 final double k = 3.4;
162 final double a = 2.3;
163 final double q = 0.567;
164 final double b = -JdkMath.log(q);
165 final double n = 3.4;
166
167 final Logistic.Parametric f = new Logistic.Parametric();
168
169 final double x = m - 1;
170 final double qExp1 = 2;
171
172 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
173
174 Assert.assertEquals((k - a) * JdkMath.log(qExp1) / (n * n * JdkMath.pow(qExp1, 1 / n)),
175 gf[5], EPS);
176 }
177
178 @Test
179 public void testGradientComponent1Component2Component3() {
180 final double m = 1.2;
181 final double k = 3.4;
182 final double a = 2.3;
183 final double b = 0.567;
184 final double q = 1 / JdkMath.exp(b * m);
185 final double n = 3.4;
186
187 final Logistic.Parametric f = new Logistic.Parametric();
188
189 final double x = 0;
190 final double qExp1 = 2;
191
192 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
193
194 final double factor = (a - k) / (n * JdkMath.pow(qExp1, 1 / n + 1));
195 Assert.assertEquals(factor * b, gf[1], EPS);
196 Assert.assertEquals(factor * m, gf[2], EPS);
197 Assert.assertEquals(factor / q, gf[3], EPS);
198 }
199 }