View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math3.analysis.function;
19  
20  import org.apache.commons.math3.analysis.UnivariateFunction;
21  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
22  import org.apache.commons.math3.exception.NotStrictlyPositiveException;
23  import org.apache.commons.math3.exception.NullArgumentException;
24  import org.apache.commons.math3.exception.DimensionMismatchException;
25  import org.apache.commons.math3.util.FastMath;
26  
27  import org.junit.Assert;
28  import org.junit.Test;
29  
30  /**
31   * Test for class {@link Logistic}.
32   */
33  public class LogisticTest {
34      private final double EPS = Math.ulp(1d);
35  
36      @Test(expected=NotStrictlyPositiveException.class)
37      public void testPreconditions1() {
38          new Logistic(1, 0, 1, 1, 0, -1);
39      }
40  
41      @Test(expected=NotStrictlyPositiveException.class)
42      public void testPreconditions2() {
43          new Logistic(1, 0, 1, 1, 0, 0);
44      }
45  
46      @Test
47      public void testCompareSigmoid() {
48          final UnivariateFunction sig = new Sigmoid();
49          final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
50  
51          final double min = -2;
52          final double max = 2;
53          final int n = 100;
54          final double delta = (max - min) / n;
55          for (int i = 0; i < n; i++) {
56              final double x = min + i * delta;
57              Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
58          }
59      }
60  
61      @Test
62      public void testSomeValues() {
63          final double k = 4;
64          final double m = 5;
65          final double b = 2;
66          final double q = 3;
67          final double a = -1;
68          final double n = 2;
69  
70          final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
71  
72          double x;
73          x = m;
74          Assert.assertEquals("x=" + x, a + (k - a) / FastMath.sqrt(1 + q), f.value(x), EPS);
75  
76          x = Double.NEGATIVE_INFINITY;
77          Assert.assertEquals("x=" + x, a, f.value(x), EPS);
78  
79          x = Double.POSITIVE_INFINITY;
80          Assert.assertEquals("x=" + x, k, f.value(x), EPS);
81      }
82  
83      @Test
84      public void testCompareDerivativeSigmoid() {
85          final double k = 3;
86          final double a = 2;
87  
88          final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
89          final Sigmoid g = new Sigmoid(a, k);
90          
91          final double min = -10;
92          final double max = 10;
93          final double n = 20;
94          final double delta = (max - min) / n;
95          for (int i = 0; i < n; i++) {
96              final DerivativeStructure x = new DerivativeStructure(1, 5, 0, min + i * delta);
97              for (int order = 0; order <= x.getOrder(); ++order) {
98                  Assert.assertEquals("x=" + x.getValue(),
99                                      g.value(x).getPartialDerivative(order),
100                                     f.value(x).getPartialDerivative(order),
101                                     3.0e-15);
102             }
103         }
104     }
105 
106     @Test(expected=NullArgumentException.class)
107     public void testParametricUsage1() {
108         final Logistic.Parametric g = new Logistic.Parametric();
109         g.value(0, null);
110     }
111 
112     @Test(expected=DimensionMismatchException.class)
113     public void testParametricUsage2() {
114         final Logistic.Parametric g = new Logistic.Parametric();
115         g.value(0, new double[] {0});
116     }
117 
118     @Test(expected=NullArgumentException.class)
119     public void testParametricUsage3() {
120         final Logistic.Parametric g = new Logistic.Parametric();
121         g.gradient(0, null);
122     }
123 
124     @Test(expected=DimensionMismatchException.class)
125     public void testParametricUsage4() {
126         final Logistic.Parametric g = new Logistic.Parametric();
127         g.gradient(0, new double[] {0});
128     }
129 
130     @Test(expected=NotStrictlyPositiveException.class)
131     public void testParametricUsage5() {
132         final Logistic.Parametric g = new Logistic.Parametric();
133         g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
134     }
135 
136     @Test(expected=NotStrictlyPositiveException.class)
137     public void testParametricUsage6() {
138         final Logistic.Parametric g = new Logistic.Parametric();
139         g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
140     }
141 
142     @Test
143     public void testGradientComponent0Component4() {
144         final double k = 3;
145         final double a = 2;
146 
147         final Logistic.Parametric f = new Logistic.Parametric();
148         // Compare using the "Sigmoid" function.
149         final Sigmoid.Parametric g = new Sigmoid.Parametric();
150         
151         final double x = 0.12345;
152         final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
153         final double[] gg = g.gradient(x, new double[] {a, k});
154 
155         Assert.assertEquals(gg[0], gf[4], EPS);
156         Assert.assertEquals(gg[1], gf[0], EPS);
157     }
158 
159     @Test
160     public void testGradientComponent5() {
161         final double m = 1.2;
162         final double k = 3.4;
163         final double a = 2.3;
164         final double q = 0.567;
165         final double b = -FastMath.log(q);
166         final double n = 3.4;
167 
168         final Logistic.Parametric f = new Logistic.Parametric();
169         
170         final double x = m - 1;
171         final double qExp1 = 2;
172 
173         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
174 
175         Assert.assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
176                             gf[5], EPS);
177     }
178 
179     @Test
180     public void testGradientComponent1Component2Component3() {
181         final double m = 1.2;
182         final double k = 3.4;
183         final double a = 2.3;
184         final double b = 0.567;
185         final double q = 1 / FastMath.exp(b * m);
186         final double n = 3.4;
187 
188         final Logistic.Parametric f = new Logistic.Parametric();
189         
190         final double x = 0;
191         final double qExp1 = 2;
192 
193         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
194 
195         final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
196         Assert.assertEquals(factor * b, gf[1], EPS);
197         Assert.assertEquals(factor * m, gf[2], EPS);
198         Assert.assertEquals(factor / q, gf[3], EPS);
199     }
200 }