1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.statistics.distribution;
19
20 import java.util.function.DoublePredicate;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 public final class HypergeometricDistribution extends AbstractDiscreteDistribution {
42
43 private static final double HALF = 0.5;
44
45 private final int numberOfSuccesses;
46
47 private final int populationSize;
48
49 private final int sampleSize;
50
51 private final int lowerBound;
52
53 private final int upperBound;
54
55 private final double bp;
56
57 private final double bq;
58
59
60 private double[] midpoint;
61
62
63
64
65
66
67 private HypergeometricDistribution(int populationSize,
68 int numberOfSuccesses,
69 int sampleSize) {
70 this.numberOfSuccesses = numberOfSuccesses;
71 this.populationSize = populationSize;
72 this.sampleSize = sampleSize;
73 lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize);
74 upperBound = getUpperDomain(numberOfSuccesses, sampleSize);
75 bp = (double) sampleSize / populationSize;
76 bq = (double) (populationSize - sampleSize) / populationSize;
77 }
78
79
80
81
82
83
84
85
86
87
88
89
90 public static HypergeometricDistribution of(int populationSize,
91 int numberOfSuccesses,
92 int sampleSize) {
93 if (populationSize <= 0) {
94 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE,
95 populationSize);
96 }
97 if (numberOfSuccesses < 0) {
98 throw new DistributionException(DistributionException.NEGATIVE,
99 numberOfSuccesses);
100 }
101 if (sampleSize < 0) {
102 throw new DistributionException(DistributionException.NEGATIVE,
103 sampleSize);
104 }
105
106 if (numberOfSuccesses > populationSize) {
107 throw new DistributionException(DistributionException.TOO_LARGE,
108 numberOfSuccesses, populationSize);
109 }
110 if (sampleSize > populationSize) {
111 throw new DistributionException(DistributionException.TOO_LARGE,
112 sampleSize, populationSize);
113 }
114 return new HypergeometricDistribution(populationSize, numberOfSuccesses, sampleSize);
115 }
116
117
118
119
120
121
122
123
124
125
126 private static int getLowerDomain(int nn, int k, int n) {
127
128
129 return Math.max(0, k - (nn - n));
130 }
131
132
133
134
135
136
137
138
139
140 private static int getUpperDomain(int k, int n) {
141 return Math.min(n, k);
142 }
143
144
145
146
147
148
149 public int getPopulationSize() {
150 return populationSize;
151 }
152
153
154
155
156
157
158 public int getNumberOfSuccesses() {
159 return numberOfSuccesses;
160 }
161
162
163
164
165
166
167 public int getSampleSize() {
168 return sampleSize;
169 }
170
171
172 @Override
173 public double probability(int x) {
174 return Math.exp(logProbability(x));
175 }
176
177
178 @Override
179 public double probability(int x0, int x1) {
180 if (x0 > x1) {
181 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
182 }
183 if (x0 == x1 || x1 < lowerBound) {
184 return 0;
185 }
186
187 if (x0 < lowerBound) {
188 return cumulativeProbability(x1);
189 }
190 if (x1 >= upperBound) {
191
192 return survivalProbability(x0);
193 }
194
195
196 final int lo = x0 + 1;
197
198 final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0));
199 return Math.abs(mode - lo) > Math.abs(mode - x1) ?
200 innerCumulativeProbability(lo, x1) :
201 innerCumulativeProbability(x1, lo);
202 }
203
204
205 @Override
206 public double logProbability(int x) {
207 if (x < lowerBound || x > upperBound) {
208 return Double.NEGATIVE_INFINITY;
209 }
210 return computeLogProbability(x);
211 }
212
213
214
215
216
217
218
219 private double computeLogProbability(int x) {
220 final double p1 =
221 SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq);
222 final double p2 =
223 SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
224 populationSize - numberOfSuccesses, bp, bq);
225 final double p3 =
226 SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq);
227 return p1 + p2 - p3;
228 }
229
230
231 @Override
232 public double cumulativeProbability(int x) {
233 if (x < lowerBound) {
234 return 0.0;
235 } else if (x >= upperBound) {
236 return 1.0;
237 }
238 final double[] mid = getMidPoint();
239 final int m = (int) mid[0];
240 if (x < m) {
241 return innerCumulativeProbability(lowerBound, x);
242 } else if (x > m) {
243 return 1 - innerCumulativeProbability(upperBound, x + 1);
244 }
245
246 return mid[1];
247 }
248
249
250 @Override
251 public double survivalProbability(int x) {
252 if (x < lowerBound) {
253 return 1.0;
254 } else if (x >= upperBound) {
255 return 0.0;
256 }
257 final double[] mid = getMidPoint();
258 final int m = (int) mid[0];
259 if (x < m) {
260 return 1 - innerCumulativeProbability(lowerBound, x);
261 } else if (x > m) {
262 return innerCumulativeProbability(upperBound, x + 1);
263 }
264
265 return 1 - mid[1];
266 }
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281 private double innerCumulativeProbability(int x0, int x1) {
282
283
284 int x = x0;
285 double ret = Math.exp(computeLogProbability(x));
286 if (x0 < x1) {
287 while (x != x1) {
288 x++;
289 ret += Math.exp(computeLogProbability(x));
290 }
291 } else {
292 while (x != x1) {
293 x--;
294 ret += Math.exp(computeLogProbability(x));
295 }
296 }
297 return ret;
298 }
299
300 @Override
301 public int inverseCumulativeProbability(double p) {
302 ArgumentUtils.checkProbability(p);
303 return computeInverseProbability(p, 1 - p, false);
304 }
305
306 @Override
307 public int inverseSurvivalProbability(double p) {
308 ArgumentUtils.checkProbability(p);
309 return computeInverseProbability(1 - p, p, true);
310 }
311
312
313
314
315
316
317
318
319
320 private int computeInverseProbability(double p, double q, boolean complement) {
321 if (p == 0) {
322 return lowerBound;
323 }
324 if (q == 0) {
325 return upperBound;
326 }
327
328
329
330
331
332
333
334 final double[] mid = getMidPoint();
335 final int m = (int) mid[0];
336 final double mp = mid[1];
337
338 final int midPointComparison = complement ?
339 Double.compare(1 - mp, q) :
340 Double.compare(p, mp);
341
342 if (midPointComparison < 0) {
343 return inverseLower(p, q, complement);
344 } else if (midPointComparison > 0) {
345
346
347
348 return Math.max(m + 1, inverseUpper(p, q, complement));
349 }
350
351 return m;
352 }
353
354
355
356
357
358
359
360
361
362 private int inverseLower(double p, double q, boolean complement) {
363
364 int x = lowerBound;
365 final DoublePredicate test = complement ?
366 i -> 1 - i > q :
367 i -> i < p;
368 double cdf = Math.exp(computeLogProbability(x));
369 while (test.test(cdf)) {
370 x++;
371 cdf += Math.exp(computeLogProbability(x));
372 }
373 return x;
374 }
375
376
377
378
379
380
381
382
383
384 private int inverseUpper(double p, double q, boolean complement) {
385
386 int x = upperBound;
387 final DoublePredicate test = complement ?
388 i -> i < q :
389 i -> 1 - i > p;
390 double sf = 0;
391 while (test.test(sf)) {
392 sf += Math.exp(computeLogProbability(x));
393 x--;
394 }
395
396
397 if (complement && sf > q ||
398 !complement && 1 - sf < p) {
399 x++;
400 }
401 return x;
402 }
403
404
405
406
407
408
409
410
411
412 @Override
413 public double getMean() {
414 return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
415 }
416
417
418
419
420
421
422
423
424
425 @Override
426 public double getVariance() {
427 final double N = getPopulationSize();
428 final double K = getNumberOfSuccesses();
429 final double n = getSampleSize();
430 return (n * K * (N - K) * (N - n)) / (N * N * (N - 1));
431 }
432
433
434
435
436
437
438
439
440
441 @Override
442 public int getSupportLowerBound() {
443 return lowerBound;
444 }
445
446
447
448
449
450
451
452
453
454 @Override
455 public int getSupportUpperBound() {
456 return upperBound;
457 }
458
459
460
461
462
463
464
465
466
467 private double[] getMidPoint() {
468 double[] v = midpoint;
469 if (v == null) {
470
471 int x = lowerBound;
472 double p0 = 0;
473 double p1 = Math.exp(computeLogProbability(x));
474
475
476 while (p1 < HALF) {
477 x++;
478 p0 = p1;
479 p1 += Math.exp(computeLogProbability(x));
480 }
481
482
483 if (p1 - HALF >= HALF - p0) {
484 x--;
485 p1 = p0;
486 }
487 midpoint = v = new double[] {x, p1};
488 }
489 return v;
490 }
491 }