View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.analysis.function;
19  
20  import org.apache.commons.math4.legacy.analysis.FunctionUtils;
21  import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
22  import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
23  import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
24  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
25  import org.apache.commons.math4.legacy.exception.NullArgumentException;
26  import org.apache.commons.math4.legacy.exception.OutOfRangeException;
27  import org.apache.commons.rng.simple.RandomSource;
28  import org.apache.commons.rng.UniformRandomProvider;
29  import org.apache.commons.math4.core.jdkmath.JdkMath;
30  import org.junit.Assert;
31  import org.junit.Test;
32  
33  /**
34   * Test for class {@link Logit}.
35   */
36  public class LogitTest {
37      private final double EPS = Math.ulp(1d);
38  
39      @Test(expected=OutOfRangeException.class)
40      public void testPreconditions1() {
41          final double lo = -1;
42          final double hi = 2;
43          final UnivariateFunction f = new Logit(lo, hi);
44  
45          f.value(lo - 1);
46      }
47  
48      @Test(expected=OutOfRangeException.class)
49      public void testPreconditions2() {
50          final double lo = -1;
51          final double hi = 2;
52          final UnivariateFunction f = new Logit(lo, hi);
53  
54          f.value(hi + 1);
55      }
56  
57      @Test
58      public void testSomeValues() {
59          final double lo = 1;
60          final double hi = 2;
61          final UnivariateFunction f = new Logit(lo, hi);
62  
63          Assert.assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS);
64          Assert.assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS);
65          Assert.assertEquals(0, f.value(1.5), EPS);
66      }
67  
68      @Test
69      public void testDerivative() {
70          final double lo = 1;
71          final double hi = 2;
72          final Logit f = new Logit(lo, hi);
73          final DerivativeStructure f15 = f.value(new DerivativeStructure(1, 1, 0, 1.5));
74  
75          Assert.assertEquals(4, f15.getPartialDerivative(1), EPS);
76      }
77  
78      @Test
79      public void testDerivativeLargeArguments() {
80          final Logit f = new Logit(1, 2);
81  
82          for (double arg : new double[] {
83              Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY
84              }) {
85              try {
86                  f.value(new DerivativeStructure(1, 1, 0, arg));
87                  Assert.fail("an exception should have been thrown");
88              } catch (OutOfRangeException ore) {
89                  // expected
90              } catch (Exception e) {
91                  Assert.fail("wrong exception caught: " + e.getMessage());
92              }
93          }
94      }
95  
96      @Test
97      public void testDerivativesHighOrder() {
98          DerivativeStructure l = new Logit(1, 3).value(new DerivativeStructure(1, 5, 0, 1.2));
99          Assert.assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16);
100         Assert.assertEquals(5.5555555555555555555,  l.getPartialDerivative(1), 9.0e-16);
101         Assert.assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14);
102         Assert.assertEquals(250.34293552812071331,  l.getPartialDerivative(3), 2.0e-13);
103         Assert.assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12);
104         Assert.assertEquals(75001.270131585632282,  l.getPartialDerivative(5), 8.0e-11);
105     }
106 
107     @Test(expected=NullArgumentException.class)
108     public void testParametricUsage1() {
109         final Logit.Parametric g = new Logit.Parametric();
110         g.value(0, null);
111     }
112 
113     @Test(expected=DimensionMismatchException.class)
114     public void testParametricUsage2() {
115         final Logit.Parametric g = new Logit.Parametric();
116         g.value(0, new double[] {0});
117     }
118 
119     @Test(expected=NullArgumentException.class)
120     public void testParametricUsage3() {
121         final Logit.Parametric g = new Logit.Parametric();
122         g.gradient(0, null);
123     }
124 
125     @Test(expected=DimensionMismatchException.class)
126     public void testParametricUsage4() {
127         final Logit.Parametric g = new Logit.Parametric();
128         g.gradient(0, new double[] {0});
129     }
130 
131     @Test(expected=OutOfRangeException.class)
132     public void testParametricUsage5() {
133         final Logit.Parametric g = new Logit.Parametric();
134         g.value(-1, new double[] {0, 1});
135     }
136 
137     @Test(expected=OutOfRangeException.class)
138     public void testParametricUsage6() {
139         final Logit.Parametric g = new Logit.Parametric();
140         g.value(2, new double[] {0, 1});
141     }
142 
143     @Test
144     public void testParametricValue() {
145         final double lo = 2;
146         final double hi = 3;
147         final Logit f = new Logit(lo, hi);
148 
149         final Logit.Parametric g = new Logit.Parametric();
150         Assert.assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
151         Assert.assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0);
152         Assert.assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0);
153     }
154 
155     @Test
156     public void testValueWithInverseFunction() {
157         final double lo = 2;
158         final double hi = 3;
159         final Logit f = new Logit(lo, hi);
160         final Sigmoid g = new Sigmoid(lo, hi);
161         final UniformRandomProvider random = RandomSource.WELL_1024_A.create(0x49914cdd9f0b8db5L);
162         final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g,
163                                                                 (UnivariateDifferentiableFunction) f);
164 
165         for (int i = 0; i < 10; i++) {
166             final double x = lo + random.nextDouble() * (hi - lo);
167             Assert.assertEquals(x, id.value(new DerivativeStructure(1, 1, 0, x)).getValue(), EPS);
168         }
169 
170         Assert.assertEquals(lo, id.value(new DerivativeStructure(1, 1, 0, lo)).getValue(), EPS);
171         Assert.assertEquals(hi, id.value(new DerivativeStructure(1, 1, 0, hi)).getValue(), EPS);
172     }
173 
174     @Test
175     public void testDerivativesWithInverseFunction() {
176         double[] epsilon = new double[] { 1e-20, 1e-15, 1.5e-14, 2e-11, 1e-8, 1e-6 };
177         final double lo = 2;
178         final double hi = 3;
179         final Logit f = new Logit(lo, hi);
180         final Sigmoid g = new Sigmoid(lo, hi);
181         final UniformRandomProvider random = RandomSource.WELL_1024_A.create();
182         final UnivariateDifferentiableFunction id =
183                 FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f);
184         for (int maxOrder = 0; maxOrder < 6; ++maxOrder) {
185             double max = 0;
186             for (int i = 0; i < 10; i++) {
187                 final double x = lo + random.nextDouble() * (hi - lo);
188                 final DerivativeStructure dsX = new DerivativeStructure(1, maxOrder, 0, x);
189                 max = JdkMath.max(max, JdkMath.abs(dsX.getPartialDerivative(maxOrder) -
190                                                      id.value(dsX).getPartialDerivative(maxOrder)));
191                 Assert.assertEquals("maxOrder = " + maxOrder,
192                                     dsX.getPartialDerivative(maxOrder),
193                                     id.value(dsX).getPartialDerivative(maxOrder),
194                                     epsilon[maxOrder]);
195             }
196 
197             // each function evaluates correctly near boundaries,
198             // but combination leads to NaN as some intermediate point is infinite
199             final DerivativeStructure dsLo = new DerivativeStructure(1, maxOrder, 0, lo);
200             if (maxOrder == 0) {
201                 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
202                 Assert.assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]);
203             } else if (maxOrder == 1) {
204                 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
205                 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
206             } else {
207                 Assert.assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder)));
208                 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
209             }
210 
211             final DerivativeStructure dsHi = new DerivativeStructure(1, maxOrder, 0, hi);
212             if (maxOrder == 0) {
213                 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
214                 Assert.assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]);
215             } else if (maxOrder == 1) {
216                 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
217                 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
218             } else {
219                 Assert.assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder)));
220                 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
221             }
222         }
223     }
224 }