View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
21  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
24  import org.apache.commons.math4.legacy.exception.NullArgumentException;
25  import org.apache.commons.math4.core.jdkmath.JdkMath;
26  import org.junit.Assert;
27  import org.junit.Test;
28  
29  /**
30   * Test for class {@link Logistic}.
31   */
32  public class LogisticTest {
33      private final double EPS = Math.ulp(1d);
34  
35      @Test(expected=NotStrictlyPositiveException.class)
36      public void testPreconditions1() {
37          new Logistic(1, 0, 1, 1, 0, -1);
38      }
39  
40      @Test(expected=NotStrictlyPositiveException.class)
41      public void testPreconditions2() {
42          new Logistic(1, 0, 1, 1, 0, 0);
43      }
44  
45      @Test
46      public void testCompareSigmoid() {
47          final UnivariateFunction sig = new Sigmoid();
48          final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
49  
50          final double min = -2;
51          final double max = 2;
52          final int n = 100;
53          final double delta = (max - min) / n;
54          for (int i = 0; i < n; i++) {
55              final double x = min + i * delta;
56              Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
57          }
58      }
59  
60      @Test
61      public void testSomeValues() {
62          final double k = 4;
63          final double m = 5;
64          final double b = 2;
65          final double q = 3;
66          final double a = -1;
67          final double n = 2;
68  
69          final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
70  
71          double x;
72          x = m;
73          Assert.assertEquals("x=" + x, a + (k - a) / JdkMath.sqrt(1 + q), f.value(x), EPS);
74  
75          x = Double.NEGATIVE_INFINITY;
76          Assert.assertEquals("x=" + x, a, f.value(x), EPS);
77  
78          x = Double.POSITIVE_INFINITY;
79          Assert.assertEquals("x=" + x, k, f.value(x), EPS);
80      }
81  
82      @Test
83      public void testCompareDerivativeSigmoid() {
84          final double k = 3;
85          final double a = 2;
86  
87          final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
88          final Sigmoid g = new Sigmoid(a, k);
89  
90          final double min = -10;
91          final double max = 10;
92          final double n = 20;
93          final double delta = (max - min) / n;
94          for (int i = 0; i < n; i++) {
95              final DerivativeStructure x = new DerivativeStructure(1, 5, 0, min + i * delta);
96              for (int order = 0; order <= x.getOrder(); ++order) {
97                  Assert.assertEquals("x=" + x.getValue(),
98                                      g.value(x).getPartialDerivative(order),
99                                      f.value(x).getPartialDerivative(order),
100                                     3.0e-15);
101             }
102         }
103     }
104 
105     @Test(expected=NullArgumentException.class)
106     public void testParametricUsage1() {
107         final Logistic.Parametric g = new Logistic.Parametric();
108         g.value(0, null);
109     }
110 
111     @Test(expected=DimensionMismatchException.class)
112     public void testParametricUsage2() {
113         final Logistic.Parametric g = new Logistic.Parametric();
114         g.value(0, new double[] {0});
115     }
116 
117     @Test(expected=NullArgumentException.class)
118     public void testParametricUsage3() {
119         final Logistic.Parametric g = new Logistic.Parametric();
120         g.gradient(0, null);
121     }
122 
123     @Test(expected=DimensionMismatchException.class)
124     public void testParametricUsage4() {
125         final Logistic.Parametric g = new Logistic.Parametric();
126         g.gradient(0, new double[] {0});
127     }
128 
129     @Test(expected=NotStrictlyPositiveException.class)
130     public void testParametricUsage5() {
131         final Logistic.Parametric g = new Logistic.Parametric();
132         g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
133     }
134 
135     @Test(expected=NotStrictlyPositiveException.class)
136     public void testParametricUsage6() {
137         final Logistic.Parametric g = new Logistic.Parametric();
138         g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
139     }
140 
141     @Test
142     public void testGradientComponent0Component4() {
143         final double k = 3;
144         final double a = 2;
145 
146         final Logistic.Parametric f = new Logistic.Parametric();
147         // Compare using the "Sigmoid" function.
148         final Sigmoid.Parametric g = new Sigmoid.Parametric();
149 
150         final double x = 0.12345;
151         final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
152         final double[] gg = g.gradient(x, new double[] {a, k});
153 
154         Assert.assertEquals(gg[0], gf[4], EPS);
155         Assert.assertEquals(gg[1], gf[0], EPS);
156     }
157 
158     @Test
159     public void testGradientComponent5() {
160         final double m = 1.2;
161         final double k = 3.4;
162         final double a = 2.3;
163         final double q = 0.567;
164         final double b = -JdkMath.log(q);
165         final double n = 3.4;
166 
167         final Logistic.Parametric f = new Logistic.Parametric();
168 
169         final double x = m - 1;
170         final double qExp1 = 2;
171 
172         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
173 
174         Assert.assertEquals((k - a) * JdkMath.log(qExp1) / (n * n * JdkMath.pow(qExp1, 1 / n)),
175                             gf[5], EPS);
176     }
177 
178     @Test
179     public void testGradientComponent1Component2Component3() {
180         final double m = 1.2;
181         final double k = 3.4;
182         final double a = 2.3;
183         final double b = 0.567;
184         final double q = 1 / JdkMath.exp(b * m);
185         final double n = 3.4;
186 
187         final Logistic.Parametric f = new Logistic.Parametric();
188 
189         final double x = 0;
190         final double qExp1 = 2;
191 
192         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
193 
194         final double factor = (a - k) / (n * JdkMath.pow(qExp1, 1 / n + 1));
195         Assert.assertEquals(factor * b, gf[1], EPS);
196         Assert.assertEquals(factor * m, gf[2], EPS);
197         Assert.assertEquals(factor / q, gf[3], EPS);
198     }
199 }