1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.analysis.function;
19
20 import org.apache.commons.math4.legacy.analysis.FunctionUtils;
21 import org.apache.commons.math4.legacy.analysis.UnivariateFunction;
22 import org.apache.commons.math4.legacy.analysis.differentiation.DerivativeStructure;
23 import org.apache.commons.math4.legacy.analysis.differentiation.UnivariateDifferentiableFunction;
24 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
25 import org.apache.commons.math4.legacy.exception.NullArgumentException;
26 import org.apache.commons.math4.legacy.exception.OutOfRangeException;
27 import org.apache.commons.rng.simple.RandomSource;
28 import org.apache.commons.rng.UniformRandomProvider;
29 import org.apache.commons.math4.core.jdkmath.JdkMath;
30 import org.junit.Assert;
31 import org.junit.Test;
32
33
34
35
36 public class LogitTest {
37 private final double EPS = Math.ulp(1d);
38
39 @Test(expected=OutOfRangeException.class)
40 public void testPreconditions1() {
41 final double lo = -1;
42 final double hi = 2;
43 final UnivariateFunction f = new Logit(lo, hi);
44
45 f.value(lo - 1);
46 }
47
48 @Test(expected=OutOfRangeException.class)
49 public void testPreconditions2() {
50 final double lo = -1;
51 final double hi = 2;
52 final UnivariateFunction f = new Logit(lo, hi);
53
54 f.value(hi + 1);
55 }
56
57 @Test
58 public void testSomeValues() {
59 final double lo = 1;
60 final double hi = 2;
61 final UnivariateFunction f = new Logit(lo, hi);
62
63 Assert.assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS);
64 Assert.assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS);
65 Assert.assertEquals(0, f.value(1.5), EPS);
66 }
67
68 @Test
69 public void testDerivative() {
70 final double lo = 1;
71 final double hi = 2;
72 final Logit f = new Logit(lo, hi);
73 final DerivativeStructure f15 = f.value(new DerivativeStructure(1, 1, 0, 1.5));
74
75 Assert.assertEquals(4, f15.getPartialDerivative(1), EPS);
76 }
77
78 @Test
79 public void testDerivativeLargeArguments() {
80 final Logit f = new Logit(1, 2);
81
82 for (double arg : new double[] {
83 Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY
84 }) {
85 try {
86 f.value(new DerivativeStructure(1, 1, 0, arg));
87 Assert.fail("an exception should have been thrown");
88 } catch (OutOfRangeException ore) {
89
90 } catch (Exception e) {
91 Assert.fail("wrong exception caught: " + e.getMessage());
92 }
93 }
94 }
95
96 @Test
97 public void testDerivativesHighOrder() {
98 DerivativeStructure l = new Logit(1, 3).value(new DerivativeStructure(1, 5, 0, 1.2));
99 Assert.assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16);
100 Assert.assertEquals(5.5555555555555555555, l.getPartialDerivative(1), 9.0e-16);
101 Assert.assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14);
102 Assert.assertEquals(250.34293552812071331, l.getPartialDerivative(3), 2.0e-13);
103 Assert.assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12);
104 Assert.assertEquals(75001.270131585632282, l.getPartialDerivative(5), 8.0e-11);
105 }
106
107 @Test(expected=NullArgumentException.class)
108 public void testParametricUsage1() {
109 final Logit.Parametric g = new Logit.Parametric();
110 g.value(0, null);
111 }
112
113 @Test(expected=DimensionMismatchException.class)
114 public void testParametricUsage2() {
115 final Logit.Parametric g = new Logit.Parametric();
116 g.value(0, new double[] {0});
117 }
118
119 @Test(expected=NullArgumentException.class)
120 public void testParametricUsage3() {
121 final Logit.Parametric g = new Logit.Parametric();
122 g.gradient(0, null);
123 }
124
125 @Test(expected=DimensionMismatchException.class)
126 public void testParametricUsage4() {
127 final Logit.Parametric g = new Logit.Parametric();
128 g.gradient(0, new double[] {0});
129 }
130
131 @Test(expected=OutOfRangeException.class)
132 public void testParametricUsage5() {
133 final Logit.Parametric g = new Logit.Parametric();
134 g.value(-1, new double[] {0, 1});
135 }
136
137 @Test(expected=OutOfRangeException.class)
138 public void testParametricUsage6() {
139 final Logit.Parametric g = new Logit.Parametric();
140 g.value(2, new double[] {0, 1});
141 }
142
143 @Test
144 public void testParametricValue() {
145 final double lo = 2;
146 final double hi = 3;
147 final Logit f = new Logit(lo, hi);
148
149 final Logit.Parametric g = new Logit.Parametric();
150 Assert.assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
151 Assert.assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0);
152 Assert.assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0);
153 }
154
155 @Test
156 public void testValueWithInverseFunction() {
157 final double lo = 2;
158 final double hi = 3;
159 final Logit f = new Logit(lo, hi);
160 final Sigmoid g = new Sigmoid(lo, hi);
161 final UniformRandomProvider random = RandomSource.WELL_1024_A.create(0x49914cdd9f0b8db5L);
162 final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g,
163 (UnivariateDifferentiableFunction) f);
164
165 for (int i = 0; i < 10; i++) {
166 final double x = lo + random.nextDouble() * (hi - lo);
167 Assert.assertEquals(x, id.value(new DerivativeStructure(1, 1, 0, x)).getValue(), EPS);
168 }
169
170 Assert.assertEquals(lo, id.value(new DerivativeStructure(1, 1, 0, lo)).getValue(), EPS);
171 Assert.assertEquals(hi, id.value(new DerivativeStructure(1, 1, 0, hi)).getValue(), EPS);
172 }
173
174 @Test
175 public void testDerivativesWithInverseFunction() {
176 double[] epsilon = new double[] { 1e-20, 1e-15, 1.5e-14, 2e-11, 1e-8, 1e-6 };
177 final double lo = 2;
178 final double hi = 3;
179 final Logit f = new Logit(lo, hi);
180 final Sigmoid g = new Sigmoid(lo, hi);
181 final UniformRandomProvider random = RandomSource.WELL_1024_A.create();
182 final UnivariateDifferentiableFunction id =
183 FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f);
184 for (int maxOrder = 0; maxOrder < 6; ++maxOrder) {
185 double max = 0;
186 for (int i = 0; i < 10; i++) {
187 final double x = lo + random.nextDouble() * (hi - lo);
188 final DerivativeStructure dsX = new DerivativeStructure(1, maxOrder, 0, x);
189 max = JdkMath.max(max, JdkMath.abs(dsX.getPartialDerivative(maxOrder) -
190 id.value(dsX).getPartialDerivative(maxOrder)));
191 Assert.assertEquals("maxOrder = " + maxOrder,
192 dsX.getPartialDerivative(maxOrder),
193 id.value(dsX).getPartialDerivative(maxOrder),
194 epsilon[maxOrder]);
195 }
196
197
198
199 final DerivativeStructure dsLo = new DerivativeStructure(1, maxOrder, 0, lo);
200 if (maxOrder == 0) {
201 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
202 Assert.assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]);
203 } else if (maxOrder == 1) {
204 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
205 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
206 } else {
207 Assert.assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder)));
208 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
209 }
210
211 final DerivativeStructure dsHi = new DerivativeStructure(1, maxOrder, 0, hi);
212 if (maxOrder == 0) {
213 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
214 Assert.assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]);
215 } else if (maxOrder == 1) {
216 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
217 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
218 } else {
219 Assert.assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder)));
220 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
221 }
222 }
223 }
224 }