1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.statistics.distribution;
19
20 import java.util.function.DoublePredicate;
21
22 /**
23 * Implementation of the hypergeometric distribution.
24 *
25 * <p>The probability mass function of \( X \) is:
26 *
27 * <p>\[ f(k; N, K, n) = \frac{\binom{K}{k} \binom{N - K}{n-k}}{\binom{N}{n}} \]
28 *
29 * <p>for \( N \in \{0, 1, 2, \dots\} \) the population size,
30 * \( K \in \{0, 1, \dots, N\} \) the number of success states,
31 * \( n \in \{0, 1, \dots, N\} \) the number of samples,
32 * \( k \in \{\max(0, n+K-N), \dots, \min(n, K)\} \) the number of successes, and
33 *
34 * <p>\[ \binom{a}{b} = \frac{a!}{b! \, (a-b)!} \]
35 *
36 * <p>is the binomial coefficient.
37 *
38 * @see <a href="https://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
39 * @see <a href="https://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
40 */
41 public final class HypergeometricDistribution extends AbstractDiscreteDistribution {
42 /** 1/2. */
43 private static final double HALF = 0.5;
44 /** The number of successes in the population. */
45 private final int numberOfSuccesses;
46 /** The population size. */
47 private final int populationSize;
48 /** The sample size. */
49 private final int sampleSize;
50 /** The lower bound of the support (inclusive). */
51 private final int lowerBound;
52 /** The upper bound of the support (inclusive). */
53 private final int upperBound;
54 /** Binomial probability of success (sampleSize / populationSize). */
55 private final double bp;
56 /** Binomial probability of failure ((populationSize - sampleSize) / populationSize). */
57 private final double bq;
58 /** Cached midpoint of the CDF/SF. The array holds [x, cdf(x)] for the midpoint x.
59 * Used for the cumulative probability functions. */
60 private double[] midpoint;
61
62 /**
63 * @param populationSize Population size.
64 * @param numberOfSuccesses Number of successes in the population.
65 * @param sampleSize Sample size.
66 */
67 private HypergeometricDistribution(int populationSize,
68 int numberOfSuccesses,
69 int sampleSize) {
70 this.numberOfSuccesses = numberOfSuccesses;
71 this.populationSize = populationSize;
72 this.sampleSize = sampleSize;
73 lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize);
74 upperBound = getUpperDomain(numberOfSuccesses, sampleSize);
75 bp = (double) sampleSize / populationSize;
76 bq = (double) (populationSize - sampleSize) / populationSize;
77 }
78
79 /**
80 * Creates a hypergeometric distribution.
81 *
82 * @param populationSize Population size.
83 * @param numberOfSuccesses Number of successes in the population.
84 * @param sampleSize Sample size.
85 * @return the distribution
86 * @throws IllegalArgumentException if {@code numberOfSuccesses < 0}, or
87 * {@code populationSize <= 0} or {@code numberOfSuccesses > populationSize}, or
88 * {@code sampleSize > populationSize}.
89 */
90 public static HypergeometricDistribution of(int populationSize,
91 int numberOfSuccesses,
92 int sampleSize) {
93 if (populationSize <= 0) {
94 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE,
95 populationSize);
96 }
97 if (numberOfSuccesses < 0) {
98 throw new DistributionException(DistributionException.NEGATIVE,
99 numberOfSuccesses);
100 }
101 if (sampleSize < 0) {
102 throw new DistributionException(DistributionException.NEGATIVE,
103 sampleSize);
104 }
105
106 if (numberOfSuccesses > populationSize) {
107 throw new DistributionException(DistributionException.TOO_LARGE,
108 numberOfSuccesses, populationSize);
109 }
110 if (sampleSize > populationSize) {
111 throw new DistributionException(DistributionException.TOO_LARGE,
112 sampleSize, populationSize);
113 }
114 return new HypergeometricDistribution(populationSize, numberOfSuccesses, sampleSize);
115 }
116
117 /**
118 * Return the lowest domain value for the given hypergeometric distribution
119 * parameters.
120 *
121 * @param nn Population size.
122 * @param k Number of successes in the population.
123 * @param n Sample size.
124 * @return the lowest domain value of the hypergeometric distribution.
125 */
126 private static int getLowerDomain(int nn, int k, int n) {
127 // Avoid overflow given N > n:
128 // n + K - N == K - (N - n)
129 return Math.max(0, k - (nn - n));
130 }
131
132 /**
133 * Return the highest domain value for the given hypergeometric distribution
134 * parameters.
135 *
136 * @param k Number of successes in the population.
137 * @param n Sample size.
138 * @return the highest domain value of the hypergeometric distribution.
139 */
140 private static int getUpperDomain(int k, int n) {
141 return Math.min(n, k);
142 }
143
144 /**
145 * Gets the population size parameter of this distribution.
146 *
147 * @return the population size.
148 */
149 public int getPopulationSize() {
150 return populationSize;
151 }
152
153 /**
154 * Gets the number of successes parameter of this distribution.
155 *
156 * @return the number of successes.
157 */
158 public int getNumberOfSuccesses() {
159 return numberOfSuccesses;
160 }
161
162 /**
163 * Gets the sample size parameter of this distribution.
164 *
165 * @return the sample size.
166 */
167 public int getSampleSize() {
168 return sampleSize;
169 }
170
171 /** {@inheritDoc} */
172 @Override
173 public double probability(int x) {
174 return Math.exp(logProbability(x));
175 }
176
177 /** {@inheritDoc} */
178 @Override
179 public double probability(int x0, int x1) {
180 if (x0 > x1) {
181 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
182 }
183 if (x0 == x1 || x1 < lowerBound) {
184 return 0;
185 }
186 // If the range is outside the bounds use the appropriate cumulative probability
187 if (x0 < lowerBound) {
188 return cumulativeProbability(x1);
189 }
190 if (x1 >= upperBound) {
191 // 1 - cdf(x0)
192 return survivalProbability(x0);
193 }
194 // Here: lower <= x0 < x1 < upper:
195 // sum(pdf(x)) for x in (x0, x1]
196 final int lo = x0 + 1;
197 // Sum small values first by starting at the point the greatest distance from the mode.
198 final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0));
199 return Math.abs(mode - lo) > Math.abs(mode - x1) ?
200 innerCumulativeProbability(lo, x1) :
201 innerCumulativeProbability(x1, lo);
202 }
203
204 /** {@inheritDoc} */
205 @Override
206 public double logProbability(int x) {
207 if (x < lowerBound || x > upperBound) {
208 return Double.NEGATIVE_INFINITY;
209 }
210 return computeLogProbability(x);
211 }
212
213 /**
214 * Compute the log probability.
215 *
216 * @param x Value.
217 * @return log(P(X = x))
218 */
219 private double computeLogProbability(int x) {
220 final double p1 =
221 SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq);
222 final double p2 =
223 SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
224 populationSize - numberOfSuccesses, bp, bq);
225 final double p3 =
226 SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq);
227 return p1 + p2 - p3;
228 }
229
230 /** {@inheritDoc} */
231 @Override
232 public double cumulativeProbability(int x) {
233 if (x < lowerBound) {
234 return 0.0;
235 } else if (x >= upperBound) {
236 return 1.0;
237 }
238 final double[] mid = getMidPoint();
239 final int m = (int) mid[0];
240 if (x < m) {
241 return innerCumulativeProbability(lowerBound, x);
242 } else if (x > m) {
243 return 1 - innerCumulativeProbability(upperBound, x + 1);
244 }
245 // cdf(x)
246 return mid[1];
247 }
248
249 /** {@inheritDoc} */
250 @Override
251 public double survivalProbability(int x) {
252 if (x < lowerBound) {
253 return 1.0;
254 } else if (x >= upperBound) {
255 return 0.0;
256 }
257 final double[] mid = getMidPoint();
258 final int m = (int) mid[0];
259 if (x < m) {
260 return 1 - innerCumulativeProbability(lowerBound, x);
261 } else if (x > m) {
262 return innerCumulativeProbability(upperBound, x + 1);
263 }
264 // 1 - cdf(x)
265 return 1 - mid[1];
266 }
267
268 /**
269 * For this distribution, {@code X}, this method returns
270 * {@code P(x0 <= X <= x1)}.
271 * This probability is computed by summing the point probabilities for the
272 * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
273 * using a comparison of the input bounds.
274 * This should be called by using {@code x0} as the domain limit and {@code x1}
275 * as the internal value. This will result in an initial sum of increasing larger magnitudes.
276 *
277 * @param x0 Inclusive domain bound.
278 * @param x1 Inclusive internal bound.
279 * @return {@code P(x0 <= X <= x1)}.
280 */
281 private double innerCumulativeProbability(int x0, int x1) {
282 // Assume the range is within the domain.
283 // Reuse the computation for probability(x) but avoid checking the domain for each call.
284 int x = x0;
285 double ret = Math.exp(computeLogProbability(x));
286 if (x0 < x1) {
287 while (x != x1) {
288 x++;
289 ret += Math.exp(computeLogProbability(x));
290 }
291 } else {
292 while (x != x1) {
293 x--;
294 ret += Math.exp(computeLogProbability(x));
295 }
296 }
297 return ret;
298 }
299
300 @Override
301 public int inverseCumulativeProbability(double p) {
302 ArgumentUtils.checkProbability(p);
303 return computeInverseProbability(p, 1 - p, false);
304 }
305
306 @Override
307 public int inverseSurvivalProbability(double p) {
308 ArgumentUtils.checkProbability(p);
309 return computeInverseProbability(1 - p, p, true);
310 }
311
312 /**
313 * Implementation for the inverse cumulative or survival probability.
314 *
315 * @param p Cumulative probability.
316 * @param q Survival probability.
317 * @param complement Set to true to compute the inverse survival probability.
318 * @return the value
319 */
320 private int computeInverseProbability(double p, double q, boolean complement) {
321 if (p == 0) {
322 return lowerBound;
323 }
324 if (q == 0) {
325 return upperBound;
326 }
327
328 // Sum the PDF(x) until the appropriate p-value is obtained
329 // CDF: require smallest x where P(X<=x) >= p
330 // SF: require smallest x where P(X>x) <= q
331 // The choice of summation uses the mid-point.
332 // The test on the CDF or SF is based on the appropriate input p-value.
333
334 final double[] mid = getMidPoint();
335 final int m = (int) mid[0];
336 final double mp = mid[1];
337
338 final int midPointComparison = complement ?
339 Double.compare(1 - mp, q) :
340 Double.compare(p, mp);
341
342 if (midPointComparison < 0) {
343 return inverseLower(p, q, complement);
344 } else if (midPointComparison > 0) {
345 // Avoid floating-point summation error when the mid-point computed using the
346 // lower sum is different to the midpoint computed using the upper sum.
347 // Here we know the result must be above the midpoint so we can clip the result.
348 return Math.max(m + 1, inverseUpper(p, q, complement));
349 }
350 // Exact mid-point
351 return m;
352 }
353
354 /**
355 * Compute the inverse cumulative or survival probability using the lower sum.
356 *
357 * @param p Cumulative probability.
358 * @param q Survival probability.
359 * @param complement Set to true to compute the inverse survival probability.
360 * @return the value
361 */
362 private int inverseLower(double p, double q, boolean complement) {
363 // Sum from the lower bound (computing the cdf)
364 int x = lowerBound;
365 final DoublePredicate test = complement ?
366 i -> 1 - i > q :
367 i -> i < p;
368 double cdf = Math.exp(computeLogProbability(x));
369 while (test.test(cdf)) {
370 x++;
371 cdf += Math.exp(computeLogProbability(x));
372 }
373 return x;
374 }
375
376 /**
377 * Compute the inverse cumulative or survival probability using the upper sum.
378 *
379 * @param p Cumulative probability.
380 * @param q Survival probability.
381 * @param complement Set to true to compute the inverse survival probability.
382 * @return the value
383 */
384 private int inverseUpper(double p, double q, boolean complement) {
385 // Sum from the upper bound (computing the sf)
386 int x = upperBound;
387 final DoublePredicate test = complement ?
388 i -> i < q :
389 i -> 1 - i > p;
390 double sf = 0;
391 while (test.test(sf)) {
392 sf += Math.exp(computeLogProbability(x));
393 x--;
394 }
395 // Here either sf(x) >= q, or cdf(x) <= p
396 // Ensure sf(x) <= q, or cdf(x) >= p
397 if (complement && sf > q ||
398 !complement && 1 - sf < p) {
399 x++;
400 }
401 return x;
402 }
403
404 /**
405 * {@inheritDoc}
406 *
407 * <p>For population size \( N \), number of successes \( K \), and sample
408 * size \( n \), the mean is:
409 *
410 * <p>\[ n \frac{K}{N} \]
411 */
412 @Override
413 public double getMean() {
414 return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
415 }
416
417 /**
418 * {@inheritDoc}
419 *
420 * <p>For population size \( N \), number of successes \( K \), and sample
421 * size \( n \), the variance is:
422 *
423 * <p>\[ n \frac{K}{N} \frac{N-K}{N} \frac{N-n}{N-1} \]
424 */
425 @Override
426 public double getVariance() {
427 final double N = getPopulationSize();
428 final double K = getNumberOfSuccesses();
429 final double n = getSampleSize();
430 return (n * K * (N - K) * (N - n)) / (N * N * (N - 1));
431 }
432
433 /**
434 * {@inheritDoc}
435 *
436 * <p>For population size \( N \), number of successes \( K \), and sample
437 * size \( n \), the lower bound of the support is \( \max \{ 0, n + K - N \} \).
438 *
439 * @return lower bound of the support
440 */
441 @Override
442 public int getSupportLowerBound() {
443 return lowerBound;
444 }
445
446 /**
447 * {@inheritDoc}
448 *
449 * <p>For number of successes \( K \), and sample
450 * size \( n \), the upper bound of the support is \( \min \{ n, K \} \).
451 *
452 * @return upper bound of the support
453 */
454 @Override
455 public int getSupportUpperBound() {
456 return upperBound;
457 }
458
459 /**
460 * Return the mid-point {@code x} of the distribution, and the cdf(x).
461 *
462 * <p>This is not the true median. It is the value where the CDF(x) is closest to 0.5;
463 * as such the CDF may be below 0.5 if the next value of x is further from 0.5.
464 *
465 * @return the mid-point ([x, cdf(x)])
466 */
467 private double[] getMidPoint() {
468 double[] v = midpoint;
469 if (v == null) {
470 // Find the closest sum(PDF) to 0.5
471 int x = lowerBound;
472 double p0 = 0;
473 double p1 = Math.exp(computeLogProbability(x));
474 // No check of the upper bound required here as the CDF should sum to 1 and 0.5
475 // is exceeded before a bounds error.
476 while (p1 < HALF) {
477 x++;
478 p0 = p1;
479 p1 += Math.exp(computeLogProbability(x));
480 }
481 // p1 >= 0.5 > p0
482 // Pick closet
483 if (p1 - HALF >= HALF - p0) {
484 x--;
485 p1 = p0;
486 }
487 midpoint = v = new double[] {x, p1};
488 }
489 return v;
490 }
491 }