1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math3.analysis.function;
19
20 import org.apache.commons.math3.analysis.UnivariateFunction;
21 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
22 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
23 import org.apache.commons.math3.exception.NullArgumentException;
24 import org.apache.commons.math3.exception.DimensionMismatchException;
25 import org.apache.commons.math3.util.FastMath;
26
27 import org.junit.Assert;
28 import org.junit.Test;
29
30
31
32
33 public class LogisticTest {
34 private final double EPS = Math.ulp(1d);
35
36 @Test(expected=NotStrictlyPositiveException.class)
37 public void testPreconditions1() {
38 new Logistic(1, 0, 1, 1, 0, -1);
39 }
40
41 @Test(expected=NotStrictlyPositiveException.class)
42 public void testPreconditions2() {
43 new Logistic(1, 0, 1, 1, 0, 0);
44 }
45
46 @Test
47 public void testCompareSigmoid() {
48 final UnivariateFunction sig = new Sigmoid();
49 final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
50
51 final double min = -2;
52 final double max = 2;
53 final int n = 100;
54 final double delta = (max - min) / n;
55 for (int i = 0; i < n; i++) {
56 final double x = min + i * delta;
57 Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
58 }
59 }
60
61 @Test
62 public void testSomeValues() {
63 final double k = 4;
64 final double m = 5;
65 final double b = 2;
66 final double q = 3;
67 final double a = -1;
68 final double n = 2;
69
70 final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
71
72 double x;
73 x = m;
74 Assert.assertEquals("x=" + x, a + (k - a) / FastMath.sqrt(1 + q), f.value(x), EPS);
75
76 x = Double.NEGATIVE_INFINITY;
77 Assert.assertEquals("x=" + x, a, f.value(x), EPS);
78
79 x = Double.POSITIVE_INFINITY;
80 Assert.assertEquals("x=" + x, k, f.value(x), EPS);
81 }
82
83 @Test
84 public void testCompareDerivativeSigmoid() {
85 final double k = 3;
86 final double a = 2;
87
88 final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
89 final Sigmoid g = new Sigmoid(a, k);
90
91 final double min = -10;
92 final double max = 10;
93 final double n = 20;
94 final double delta = (max - min) / n;
95 for (int i = 0; i < n; i++) {
96 final DerivativeStructure x = new DerivativeStructure(1, 5, 0, min + i * delta);
97 for (int order = 0; order <= x.getOrder(); ++order) {
98 Assert.assertEquals("x=" + x.getValue(),
99 g.value(x).getPartialDerivative(order),
100 f.value(x).getPartialDerivative(order),
101 3.0e-15);
102 }
103 }
104 }
105
106 @Test(expected=NullArgumentException.class)
107 public void testParametricUsage1() {
108 final Logistic.Parametric g = new Logistic.Parametric();
109 g.value(0, null);
110 }
111
112 @Test(expected=DimensionMismatchException.class)
113 public void testParametricUsage2() {
114 final Logistic.Parametric g = new Logistic.Parametric();
115 g.value(0, new double[] {0});
116 }
117
118 @Test(expected=NullArgumentException.class)
119 public void testParametricUsage3() {
120 final Logistic.Parametric g = new Logistic.Parametric();
121 g.gradient(0, null);
122 }
123
124 @Test(expected=DimensionMismatchException.class)
125 public void testParametricUsage4() {
126 final Logistic.Parametric g = new Logistic.Parametric();
127 g.gradient(0, new double[] {0});
128 }
129
130 @Test(expected=NotStrictlyPositiveException.class)
131 public void testParametricUsage5() {
132 final Logistic.Parametric g = new Logistic.Parametric();
133 g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
134 }
135
136 @Test(expected=NotStrictlyPositiveException.class)
137 public void testParametricUsage6() {
138 final Logistic.Parametric g = new Logistic.Parametric();
139 g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
140 }
141
142 @Test
143 public void testGradientComponent0Component4() {
144 final double k = 3;
145 final double a = 2;
146
147 final Logistic.Parametric f = new Logistic.Parametric();
148
149 final Sigmoid.Parametric g = new Sigmoid.Parametric();
150
151 final double x = 0.12345;
152 final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
153 final double[] gg = g.gradient(x, new double[] {a, k});
154
155 Assert.assertEquals(gg[0], gf[4], EPS);
156 Assert.assertEquals(gg[1], gf[0], EPS);
157 }
158
159 @Test
160 public void testGradientComponent5() {
161 final double m = 1.2;
162 final double k = 3.4;
163 final double a = 2.3;
164 final double q = 0.567;
165 final double b = -FastMath.log(q);
166 final double n = 3.4;
167
168 final Logistic.Parametric f = new Logistic.Parametric();
169
170 final double x = m - 1;
171 final double qExp1 = 2;
172
173 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
174
175 Assert.assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
176 gf[5], EPS);
177 }
178
179 @Test
180 public void testGradientComponent1Component2Component3() {
181 final double m = 1.2;
182 final double k = 3.4;
183 final double a = 2.3;
184 final double b = 0.567;
185 final double q = 1 / FastMath.exp(b * m);
186 final double n = 3.4;
187
188 final Logistic.Parametric f = new Logistic.Parametric();
189
190 final double x = 0;
191 final double qExp1 = 2;
192
193 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
194
195 final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
196 Assert.assertEquals(factor * b, gf[1], EPS);
197 Assert.assertEquals(factor * m, gf[2], EPS);
198 Assert.assertEquals(factor / q, gf[3], EPS);
199 }
200 }