1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
21  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
24  import org.apache.commons.math4.legacy.exception.NullArgumentException;
25  import org.apache.commons.math4.core.jdkmath.JdkMath;
26  import org.junit.Assert;
27  import org.junit.Test;
28  
29  
30  
31  
32  public class LogisticTest {
33      private final double EPS = Math.ulp(1d);
34  
35      @Test(expected=NotStrictlyPositiveException.class)
36      public void testPreconditions1() {
37          new Logistic(1, 0, 1, 1, 0, -1);
38      }
39  
40      @Test(expected=NotStrictlyPositiveException.class)
41      public void testPreconditions2() {
42          new Logistic(1, 0, 1, 1, 0, 0);
43      }
44  
45      @Test
46      public void testCompareSigmoid() {
47          final UnivariateFunction sig = new Sigmoid();
48          final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
49  
50          final double min = -2;
51          final double max = 2;
52          final int n = 100;
53          final double delta = (max - min) / n;
54          for (int i = 0; i < n; i++) {
55              final double x = min + i * delta;
56              Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
57          }
58      }
59  
60      @Test
61      public void testSomeValues() {
62          final double k = 4;
63          final double m = 5;
64          final double b = 2;
65          final double q = 3;
66          final double a = -1;
67          final double n = 2;
68  
69          final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
70  
71          double x;
72          x = m;
73          Assert.assertEquals("x=" + x, a + (k - a) / JdkMath.sqrt(1 + q), f.value(x), EPS);
74  
75          x = Double.NEGATIVE_INFINITY;
76          Assert.assertEquals("x=" + x, a, f.value(x), EPS);
77  
78          x = Double.POSITIVE_INFINITY;
79          Assert.assertEquals("x=" + x, k, f.value(x), EPS);
80      }
81  
82      @Test
83      public void testCompareDerivativeSigmoid() {
84          final double k = 3;
85          final double a = 2;
86  
87          final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
88          final Sigmoid g = new Sigmoid(a, k);
89  
90          final double min = -10;
91          final double max = 10;
92          final double n = 20;
93          final double delta = (max - min) / n;
94          for (int i = 0; i < n; i++) {
95              final DerivativeStructure x = new DerivativeStructure(1, 5, 0, min + i * delta);
96              for (int order = 0; order <= x.getOrder(); ++order) {
97                  Assert.assertEquals("x=" + x.getValue(),
98                                      g.value(x).getPartialDerivative(order),
99                                      f.value(x).getPartialDerivative(order),
100                                     3.0e-15);
101             }
102         }
103     }
104 
105     @Test(expected=NullArgumentException.class)
106     public void testParametricUsage1() {
107         final Logistic.Parametric g = new Logistic.Parametric();
108         g.value(0, null);
109     }
110 
111     @Test(expected=DimensionMismatchException.class)
112     public void testParametricUsage2() {
113         final Logistic.Parametric g = new Logistic.Parametric();
114         g.value(0, new double[] {0});
115     }
116 
117     @Test(expected=NullArgumentException.class)
118     public void testParametricUsage3() {
119         final Logistic.Parametric g = new Logistic.Parametric();
120         g.gradient(0, null);
121     }
122 
123     @Test(expected=DimensionMismatchException.class)
124     public void testParametricUsage4() {
125         final Logistic.Parametric g = new Logistic.Parametric();
126         g.gradient(0, new double[] {0});
127     }
128 
129     @Test(expected=NotStrictlyPositiveException.class)
130     public void testParametricUsage5() {
131         final Logistic.Parametric g = new Logistic.Parametric();
132         g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
133     }
134 
135     @Test(expected=NotStrictlyPositiveException.class)
136     public void testParametricUsage6() {
137         final Logistic.Parametric g = new Logistic.Parametric();
138         g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
139     }
140 
141     @Test
142     public void testGradientComponent0Component4() {
143         final double k = 3;
144         final double a = 2;
145 
146         final Logistic.Parametric f = new Logistic.Parametric();
147         
148         final Sigmoid.Parametric g = new Sigmoid.Parametric();
149 
150         final double x = 0.12345;
151         final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
152         final double[] gg = g.gradient(x, new double[] {a, k});
153 
154         Assert.assertEquals(gg[0], gf[4], EPS);
155         Assert.assertEquals(gg[1], gf[0], EPS);
156     }
157 
158     @Test
159     public void testGradientComponent5() {
160         final double m = 1.2;
161         final double k = 3.4;
162         final double a = 2.3;
163         final double q = 0.567;
164         final double b = -JdkMath.log(q);
165         final double n = 3.4;
166 
167         final Logistic.Parametric f = new Logistic.Parametric();
168 
169         final double x = m - 1;
170         final double qExp1 = 2;
171 
172         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
173 
174         Assert.assertEquals((k - a) * JdkMath.log(qExp1) / (n * n * JdkMath.pow(qExp1, 1 / n)),
175                             gf[5], EPS);
176     }
177 
178     @Test
179     public void testGradientComponent1Component2Component3() {
180         final double m = 1.2;
181         final double k = 3.4;
182         final double a = 2.3;
183         final double b = 0.567;
184         final double q = 1 / JdkMath.exp(b * m);
185         final double n = 3.4;
186 
187         final Logistic.Parametric f = new Logistic.Parametric();
188 
189         final double x = 0;
190         final double qExp1 = 2;
191 
192         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
193 
194         final double factor = (a - k) / (n * JdkMath.pow(qExp1, 1 / n + 1));
195         Assert.assertEquals(factor * b, gf[1], EPS);
196         Assert.assertEquals(factor * m, gf[2], EPS);
197         Assert.assertEquals(factor / q, gf[3], EPS);
198     }
199 }