1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.statistics.inference;
18
19 import java.util.Arrays;
20 import java.util.stream.IntStream;
21 import java.util.stream.Stream;
22 import org.apache.commons.statistics.distribution.BinomialDistribution;
23 import org.junit.jupiter.api.Assertions;
24 import org.junit.jupiter.api.Test;
25 import org.junit.jupiter.params.ParameterizedTest;
26 import org.junit.jupiter.params.provider.Arguments;
27 import org.junit.jupiter.params.provider.CsvSource;
28 import org.junit.jupiter.params.provider.MethodSource;
29
30
31
32
33 class BinomialTestTest {
34
35 @Test
36 void testInvalidOptionsThrows() {
37 final BinomialTest test = BinomialTest.withDefaults();
38 Assertions.assertThrows(NullPointerException.class, () ->
39 test.with((AlternativeHypothesis) null));
40 }
41
42 @ParameterizedTest
43 @CsvSource({
44 "10, 5, -1",
45 "10, 5, 2",
46 "10, -1, 0.5",
47 "10, 11, 0.5",
48 "-1, 5, 0.5",
49 "1, 2, 0.5",
50 })
51 void testBinomialTestThrows(int n, int k, double p) {
52 final BinomialTest test = BinomialTest.withDefaults();
53 Assertions.assertThrows(IllegalArgumentException.class, () -> test.test(n, k, p));
54 }
55
56
57
58
59
60
61
62
63
64
65
66
67
68 @ParameterizedTest
69 @CsvSource({
70 "0, 0.25, 1e-15, 0",
71 "1, 0.25, 1e-15, 0",
72 "2, 0.25, 1e-15, 0",
73 "0, 0.5, 1e-15, 0",
74 "1, 0.5, 1e-15, 0",
75 "2, 0.5, 1e-15, 0",
76 "0, 0.75, 1e-15, 0",
77 "1, 0.75, 1e-15, 0",
78 "2, 0.75, 1e-15, 0",
79 "10, 0.25, 2e-15, 0",
80 "10, 0.49, 2e-15, 0",
81 "10, 0.5, 2e-15, 0",
82 "10, 0.51, 2e-15, 0",
83 "10, 0.75, 2e-15, 0",
84 "11, 0.25, 3e-15, 0",
85 "11, 0.49, 2e-15, 0",
86 "11, 0.5, 2e-15, 0",
87 "11, 0.51, 2e-15, 0",
88 "11, 0.75, 3e-15, 0",
89 "5, 0.1, 2e-15, 0",
90 "5, 0.7, 1e-15, 0",
91 "20, 0.7, 3e-15, 0",
92 })
93 void testBinomialTest(int n, double p, double eps) {
94 final BinomialDistribution dist = BinomialDistribution.of(n, p);
95 final double[] pk = IntStream.rangeClosed(0, n).mapToDouble(dist::probability).toArray();
96
97
98
99 final double maxP = Math.nextDown(1.0);
100
101 final BinomialTest twoSided = BinomialTest.withDefaults();
102 final BinomialTest less = BinomialTest.withDefaults().with(AlternativeHypothesis.LESS_THAN);
103 final BinomialTest greater = BinomialTest.withDefaults().with(AlternativeHypothesis.GREATER_THAN);
104
105 IntStream.rangeClosed(0, n).forEach(k -> {
106 double expected;
107
108
109 expected = Math.min(maxP, IntStream.rangeClosed(0, k).mapToDouble(i -> pk[i]).sum());
110 TestUtils.assertProbability(expected,
111 less.test(n, k, p).getPValue(), eps,
112 () -> "less than: k=" + k);
113
114 expected = Math.min(maxP, IntStream.rangeClosed(k, n).mapToDouble(i -> pk[i]).sum());
115 TestUtils.assertProbability(expected,
116 greater.test(n, k, p).getPValue(), eps,
117 () -> "greater than: k=" + k);
118
119
120
121 expected = Math.min(maxP, Arrays.stream(pk).filter(x -> x <= pk[k]).sum());
122 TestUtils.assertProbability(expected,
123 twoSided.test(n, k, p).getPValue(), eps,
124 () -> "two-sided: k=" + k);
125 });
126 }
127
128
129
130
131
132
133
134
135
136
137
138
139 @ParameterizedTest
140 @CsvSource({
141 "1234, 0.3",
142 "1234, 0.55",
143 "1234, 0.87",
144 "12345, 0.3",
145 "12345, 0.55",
146 "12345, 0.87",
147
148 "10000, 0.49999",
149 "10000, 0.50001",
150 })
151 void testBinomialTestLargeN(int n, double p) {
152 final BinomialDistribution dist = BinomialDistribution.of(n, p);
153
154
155
156 final double[] pk = IntStream.rangeClosed(0, n).mapToDouble(dist::logProbability).toArray();
157
158
159
160 final double eps = 0;
161
162 final BinomialTest twoSided = BinomialTest.withDefaults();
163 final BinomialTest less = BinomialTest.withDefaults().with(AlternativeHypothesis.LESS_THAN);
164 final BinomialTest greater = BinomialTest.withDefaults().with(AlternativeHypothesis.GREATER_THAN);
165
166 IntStream.rangeClosed(0, n).forEach(k -> {
167 double expected;
168
169
170 expected = dist.cumulativeProbability(k);
171 TestUtils.assertProbability(expected,
172 less.test(n, k, p).getPValue(), eps,
173 () -> "less than: k=" + k);
174
175 expected = dist.survivalProbability(k - 1);
176 TestUtils.assertProbability(expected,
177 greater.test(n, k, p).getPValue(), eps,
178 () -> "greater than: k=" + k);
179
180
181
182 int i = -1;
183 while (i < n && pk[i + 1] <= pk[k]) {
184 i++;
185 }
186 int j = n + 1;
187 while (j > 0 && pk[j - 1] <= pk[k]) {
188 j--;
189 }
190 expected = j <= i ? 1 : dist.cumulativeProbability(i) + dist.survivalProbability(j - 1);
191 TestUtils.assertProbability(expected,
192 twoSided.test(n, k, p).getPValue(), eps,
193 () -> "two-sided: k=" + k);
194 });
195 }
196
197 @Test
198 void testBinomialTestPValues() {
199 final int successes = 51;
200 final int trials = 235;
201 final double probability = 1.0 / 6.0;
202
203 Assertions.assertEquals(0.04375, BinomialTest.withDefaults()
204 .test(trials, successes, probability).getPValue(), 1e-4);
205 Assertions.assertEquals(0.02654, BinomialTest.withDefaults().with(AlternativeHypothesis.GREATER_THAN)
206 .test(trials, successes, probability).getPValue(), 1e-4);
207 Assertions.assertEquals(0.982, BinomialTest.withDefaults().with(AlternativeHypothesis.LESS_THAN)
208 .test(trials, successes, probability).getPValue(), 1e-4);
209
210
211 final BinomialTest twoSided = BinomialTest.withDefaults();
212 Assertions.assertEquals(1, twoSided.test(3, 3, 1).getPValue(), 1e-4);
213 Assertions.assertEquals(1, twoSided.test(3, 3, 0.9).getPValue(), 1e-4);
214 Assertions.assertEquals(1, twoSided.test(3, 3, 0.8).getPValue(), 1e-4);
215 Assertions.assertEquals(0.559, twoSided.test(3, 3, 0.7).getPValue(), 1e-4);
216 Assertions.assertEquals(0.28, twoSided.test(3, 3, 0.6).getPValue(), 1e-4);
217 Assertions.assertEquals(0.25, twoSided.test(3, 3, 0.5).getPValue(), 1e-4);
218 Assertions.assertEquals(0.064, twoSided.test(3, 3, 0.4).getPValue(), 1e-4);
219 Assertions.assertEquals(0.027, twoSided.test(3, 3, 0.3).getPValue(), 1e-4);
220 Assertions.assertEquals(0.008, twoSided.test(3, 3, 0.2).getPValue(), 1e-4);
221 Assertions.assertEquals(0.001, twoSided.test(3, 3, 0.1).getPValue(), 1e-4);
222 Assertions.assertEquals(0, twoSided.test(3, 3, 0.0).getPValue(), 1e-4);
223
224 Assertions.assertEquals(0, twoSided.test(3, 0, 1).getPValue(), 1e-4);
225 Assertions.assertEquals(0.001, twoSided.test(3, 0, 0.9).getPValue(), 1e-4);
226 Assertions.assertEquals(0.008, twoSided.test(3, 0, 0.8).getPValue(), 1e-4);
227 Assertions.assertEquals(0.027, twoSided.test(3, 0, 0.7).getPValue(), 1e-4);
228 Assertions.assertEquals(0.064, twoSided.test(3, 0, 0.6).getPValue(), 1e-4);
229 Assertions.assertEquals(0.25, twoSided.test(3, 0, 0.5).getPValue(), 1e-4);
230 Assertions.assertEquals(0.28, twoSided.test(3, 0, 0.4).getPValue(), 1e-4);
231 Assertions.assertEquals(0.559, twoSided.test(3, 0, 0.3).getPValue(), 1e-4);
232 Assertions.assertEquals(1, twoSided.test(3, 0, 0.2).getPValue(), 1e-4);
233 Assertions.assertEquals(1, twoSided.test(3, 0, 0.1).getPValue(), 1e-4);
234 Assertions.assertEquals(1, twoSided.test(3, 0, 0.0).getPValue(), 1e-4);
235 }
236
237 @ParameterizedTest
238 @CsvSource({
239
240 "10, 5, 0.5",
241 "11, 5, 0.5",
242 "11, 6, 0.5",
243 "20, 5, 0.25",
244 "21, 5, 0.25",
245 "21, 6, 0.25",
246 "20, 15, 0.75",
247 "21, 15, 0.75",
248 "21, 16, 0.75",
249 })
250 void testMath1644(int n, int k, double p) {
251 final double pval = BinomialTest.withDefaults().test(n, k, p).getPValue();
252 Assertions.assertTrue(pval <= 1, () -> "pval=" + pval);
253 }
254
255 @ParameterizedTest
256 @MethodSource
257 void testBinomTest(int n, int k, double probability, double[] p) {
258 int i = 0;
259 for (final AlternativeHypothesis h : AlternativeHypothesis.values()) {
260 final SignificanceResult r = BinomialTest.withDefaults().with(h).test(n, k, probability);
261 Assertions.assertEquals((double) k / n, r.getStatistic(), "statistic");
262 TestUtils.assertProbability(p[i++], r.getPValue(), 1e-14, "p-value");
263 }
264 }
265
266 static Stream<Arguments> testBinomTest() {
267
268
269 final Stream.Builder<Arguments> builder = Stream.builder();
270 builder.add(Arguments.of(15, 3, 0.1,
271 new double[] {0.18406106910639106, 0.18406106910639106, 0.944444369992464}));
272 builder.add(Arguments.of(150, 37, 0.25,
273 new double[] {1.0, 0.5687513546881982, 0.5062937783866548}));
274 builder.add(Arguments.of(150, 67, 0.25,
275 new double[] {2.083753914662947e-07, 1.2964820621216238e-07, 0.9999999481384629}));
276 builder.add(Arguments.of(150, 17, 0.25,
277 new double[] {4.229481760264341e-05, 0.9999911956737946, 2.399451075709081e-05}));
278 return builder.build();
279 }
280 }