View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.statistics.distribution;
19  
20  import java.util.function.DoublePredicate;
21  
22  /**
23   * Implementation of the hypergeometric distribution.
24   *
25   * <p>The probability mass function of \( X \) is:
26   *
27   * <p>\[ f(k; N, K, n) = \frac{\binom{K}{k} \binom{N - K}{n-k}}{\binom{N}{n}} \]
28   *
29   * <p>for \( N \in \{0, 1, 2, \dots\} \) the population size,
30   * \( K \in \{0, 1, \dots, N\} \) the number of success states,
31   * \( n \in \{0, 1, \dots, N\} \) the number of samples,
32   * \( k \in \{\max(0, n+K-N), \dots, \min(n, K)\} \) the number of successes, and
33   *
34   * <p>\[ \binom{a}{b} = \frac{a!}{b! \, (a-b)!} \]
35   *
36   * <p>is the binomial coefficient.
37   *
38   * @see <a href="https://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
39   * @see <a href="https://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
40   */
41  public final class HypergeometricDistribution extends AbstractDiscreteDistribution {
42      /** 1/2. */
43      private static final double HALF = 0.5;
44      /** The number of successes in the population. */
45      private final int numberOfSuccesses;
46      /** The population size. */
47      private final int populationSize;
48      /** The sample size. */
49      private final int sampleSize;
50      /** The lower bound of the support (inclusive). */
51      private final int lowerBound;
52      /** The upper bound of the support (inclusive). */
53      private final int upperBound;
54      /** Binomial probability of success (sampleSize / populationSize). */
55      private final double bp;
56      /** Binomial probability of failure ((populationSize - sampleSize) / populationSize). */
57      private final double bq;
58      /** Cached midpoint of the CDF/SF. The array holds [x, cdf(x)] for the midpoint x.
59       * Used for the cumulative probability functions. */
60      private double[] midpoint;
61  
62      /**
63       * @param populationSize Population size.
64       * @param numberOfSuccesses Number of successes in the population.
65       * @param sampleSize Sample size.
66       */
67      private HypergeometricDistribution(int populationSize,
68                                         int numberOfSuccesses,
69                                         int sampleSize) {
70          this.numberOfSuccesses = numberOfSuccesses;
71          this.populationSize = populationSize;
72          this.sampleSize = sampleSize;
73          lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize);
74          upperBound = getUpperDomain(numberOfSuccesses, sampleSize);
75          bp = (double) sampleSize / populationSize;
76          bq = (double) (populationSize - sampleSize) / populationSize;
77      }
78  
79      /**
80       * Creates a hypergeometric distribution.
81       *
82       * @param populationSize Population size.
83       * @param numberOfSuccesses Number of successes in the population.
84       * @param sampleSize Sample size.
85       * @return the distribution
86       * @throws IllegalArgumentException if {@code numberOfSuccesses < 0}, or
87       * {@code populationSize <= 0} or {@code numberOfSuccesses > populationSize}, or
88       * {@code sampleSize > populationSize}.
89       */
90      public static HypergeometricDistribution of(int populationSize,
91                                                  int numberOfSuccesses,
92                                                  int sampleSize) {
93          if (populationSize <= 0) {
94              throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE,
95                                              populationSize);
96          }
97          if (numberOfSuccesses < 0) {
98              throw new DistributionException(DistributionException.NEGATIVE,
99                                              numberOfSuccesses);
100         }
101         if (sampleSize < 0) {
102             throw new DistributionException(DistributionException.NEGATIVE,
103                                             sampleSize);
104         }
105 
106         if (numberOfSuccesses > populationSize) {
107             throw new DistributionException(DistributionException.TOO_LARGE,
108                                             numberOfSuccesses, populationSize);
109         }
110         if (sampleSize > populationSize) {
111             throw new DistributionException(DistributionException.TOO_LARGE,
112                                             sampleSize, populationSize);
113         }
114         return new HypergeometricDistribution(populationSize, numberOfSuccesses, sampleSize);
115     }
116 
117     /**
118      * Return the lowest domain value for the given hypergeometric distribution
119      * parameters.
120      *
121      * @param nn Population size.
122      * @param k Number of successes in the population.
123      * @param n Sample size.
124      * @return the lowest domain value of the hypergeometric distribution.
125      */
126     private static int getLowerDomain(int nn, int k, int n) {
127         // Avoid overflow given N > n:
128         // n + K - N == K - (N - n)
129         return Math.max(0, k - (nn - n));
130     }
131 
132     /**
133      * Return the highest domain value for the given hypergeometric distribution
134      * parameters.
135      *
136      * @param k Number of successes in the population.
137      * @param n Sample size.
138      * @return the highest domain value of the hypergeometric distribution.
139      */
140     private static int getUpperDomain(int k, int n) {
141         return Math.min(n, k);
142     }
143 
144     /**
145      * Gets the population size parameter of this distribution.
146      *
147      * @return the population size.
148      */
149     public int getPopulationSize() {
150         return populationSize;
151     }
152 
153     /**
154      * Gets the number of successes parameter of this distribution.
155      *
156      * @return the number of successes.
157      */
158     public int getNumberOfSuccesses() {
159         return numberOfSuccesses;
160     }
161 
162     /**
163      * Gets the sample size parameter of this distribution.
164      *
165      * @return the sample size.
166      */
167     public int getSampleSize() {
168         return sampleSize;
169     }
170 
171     /** {@inheritDoc} */
172     @Override
173     public double probability(int x) {
174         return Math.exp(logProbability(x));
175     }
176 
177     /** {@inheritDoc} */
178     @Override
179     public double probability(int x0, int x1) {
180         if (x0 > x1) {
181             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
182         }
183         if (x0 == x1 || x1 < lowerBound) {
184             return 0;
185         }
186         // If the range is outside the bounds use the appropriate cumulative probability
187         if (x0 < lowerBound) {
188             return cumulativeProbability(x1);
189         }
190         if (x1 >= upperBound) {
191             // 1 - cdf(x0)
192             return survivalProbability(x0);
193         }
194         // Here: lower <= x0 < x1 < upper:
195         // sum(pdf(x)) for x in (x0, x1]
196         final int lo = x0 + 1;
197         // Sum small values first by starting at the point the greatest distance from the mode.
198         final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0));
199         return Math.abs(mode - lo) > Math.abs(mode - x1) ?
200             innerCumulativeProbability(lo, x1) :
201             innerCumulativeProbability(x1, lo);
202     }
203 
204     /** {@inheritDoc} */
205     @Override
206     public double logProbability(int x) {
207         if (x < lowerBound || x > upperBound) {
208             return Double.NEGATIVE_INFINITY;
209         }
210         return computeLogProbability(x);
211     }
212 
213     /**
214      * Compute the log probability.
215      *
216      * @param x Value.
217      * @return log(P(X = x))
218      */
219     private double computeLogProbability(int x) {
220         final double p1 =
221                 SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq);
222         final double p2 =
223                 SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
224                         populationSize - numberOfSuccesses, bp, bq);
225         final double p3 =
226                 SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq);
227         return p1 + p2 - p3;
228     }
229 
230     /** {@inheritDoc} */
231     @Override
232     public double cumulativeProbability(int x) {
233         if (x < lowerBound) {
234             return 0.0;
235         } else if (x >= upperBound) {
236             return 1.0;
237         }
238         final double[] mid = getMidPoint();
239         final int m = (int) mid[0];
240         if (x < m) {
241             return innerCumulativeProbability(lowerBound, x);
242         } else if (x > m) {
243             return 1 - innerCumulativeProbability(upperBound, x + 1);
244         }
245         // cdf(x)
246         return mid[1];
247     }
248 
249     /** {@inheritDoc} */
250     @Override
251     public double survivalProbability(int x) {
252         if (x < lowerBound) {
253             return 1.0;
254         } else if (x >= upperBound) {
255             return 0.0;
256         }
257         final double[] mid = getMidPoint();
258         final int m = (int) mid[0];
259         if (x < m) {
260             return 1 - innerCumulativeProbability(lowerBound, x);
261         } else if (x > m) {
262             return innerCumulativeProbability(upperBound, x + 1);
263         }
264         // 1 - cdf(x)
265         return 1 - mid[1];
266     }
267 
268     /**
269      * For this distribution, {@code X}, this method returns
270      * {@code P(x0 <= X <= x1)}.
271      * This probability is computed by summing the point probabilities for the
272      * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
273      * using a comparison of the input bounds.
274      * This should be called by using {@code x0} as the domain limit and {@code x1}
275      * as the internal value. This will result in an initial sum of increasing larger magnitudes.
276      *
277      * @param x0 Inclusive domain bound.
278      * @param x1 Inclusive internal bound.
279      * @return {@code P(x0 <= X <= x1)}.
280      */
281     private double innerCumulativeProbability(int x0, int x1) {
282         // Assume the range is within the domain.
283         // Reuse the computation for probability(x) but avoid checking the domain for each call.
284         int x = x0;
285         double ret = Math.exp(computeLogProbability(x));
286         if (x0 < x1) {
287             while (x != x1) {
288                 x++;
289                 ret += Math.exp(computeLogProbability(x));
290             }
291         } else {
292             while (x != x1) {
293                 x--;
294                 ret += Math.exp(computeLogProbability(x));
295             }
296         }
297         return ret;
298     }
299 
300     @Override
301     public int inverseCumulativeProbability(double p) {
302         ArgumentUtils.checkProbability(p);
303         return computeInverseProbability(p, 1 - p, false);
304     }
305 
306     @Override
307     public int inverseSurvivalProbability(double p) {
308         ArgumentUtils.checkProbability(p);
309         return computeInverseProbability(1 - p, p, true);
310     }
311 
312     /**
313      * Implementation for the inverse cumulative or survival probability.
314      *
315      * @param p Cumulative probability.
316      * @param q Survival probability.
317      * @param complement Set to true to compute the inverse survival probability.
318      * @return the value
319      */
320     private int computeInverseProbability(double p, double q, boolean complement) {
321         if (p == 0) {
322             return lowerBound;
323         }
324         if (q == 0) {
325             return upperBound;
326         }
327 
328         // Sum the PDF(x) until the appropriate p-value is obtained
329         // CDF: require smallest x where P(X<=x) >= p
330         // SF:  require smallest x where P(X>x) <= q
331         // The choice of summation uses the mid-point.
332         // The test on the CDF or SF is based on the appropriate input p-value.
333 
334         final double[] mid = getMidPoint();
335         final int m = (int) mid[0];
336         final double mp = mid[1];
337 
338         final int midPointComparison = complement ?
339             Double.compare(1 - mp, q) :
340             Double.compare(p, mp);
341 
342         if (midPointComparison < 0) {
343             return inverseLower(p, q, complement);
344         } else if (midPointComparison > 0) {
345             // Avoid floating-point summation error when the mid-point computed using the
346             // lower sum is different to the midpoint computed using the upper sum.
347             // Here we know the result must be above the midpoint so we can clip the result.
348             return Math.max(m + 1, inverseUpper(p, q, complement));
349         }
350         // Exact mid-point
351         return m;
352     }
353 
354     /**
355      * Compute the inverse cumulative or survival probability using the lower sum.
356      *
357      * @param p Cumulative probability.
358      * @param q Survival probability.
359      * @param complement Set to true to compute the inverse survival probability.
360      * @return the value
361      */
362     private int inverseLower(double p, double q, boolean complement) {
363         // Sum from the lower bound (computing the cdf)
364         int x = lowerBound;
365         final DoublePredicate test = complement ?
366             i -> 1 - i > q :
367             i -> i < p;
368         double cdf = Math.exp(computeLogProbability(x));
369         while (test.test(cdf)) {
370             x++;
371             cdf += Math.exp(computeLogProbability(x));
372         }
373         return x;
374     }
375 
376     /**
377      * Compute the inverse cumulative or survival probability using the upper sum.
378      *
379      * @param p Cumulative probability.
380      * @param q Survival probability.
381      * @param complement Set to true to compute the inverse survival probability.
382      * @return the value
383      */
384     private int inverseUpper(double p, double q, boolean complement) {
385         // Sum from the upper bound (computing the sf)
386         int x = upperBound;
387         final DoublePredicate test = complement ?
388             i -> i < q :
389             i -> 1 - i > p;
390         double sf = 0;
391         while (test.test(sf)) {
392             sf += Math.exp(computeLogProbability(x));
393             x--;
394         }
395         // Here either sf(x) >= q, or cdf(x) <= p
396         // Ensure sf(x) <= q, or cdf(x) >= p
397         if (complement && sf > q ||
398             !complement && 1 - sf < p) {
399             x++;
400         }
401         return x;
402     }
403 
404     /**
405      * {@inheritDoc}
406      *
407      * <p>For population size \( N \), number of successes \( K \), and sample
408      * size \( n \), the mean is:
409      *
410      * <p>\[ n \frac{K}{N} \]
411      */
412     @Override
413     public double getMean() {
414         return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
415     }
416 
417     /**
418      * {@inheritDoc}
419      *
420      * <p>For population size \( N \), number of successes \( K \), and sample
421      * size \( n \), the variance is:
422      *
423      * <p>\[ n \frac{K}{N} \frac{N-K}{N} \frac{N-n}{N-1} \]
424      */
425     @Override
426     public double getVariance() {
427         final double N = getPopulationSize();
428         final double K = getNumberOfSuccesses();
429         final double n = getSampleSize();
430         return (n * K * (N - K) * (N - n)) / (N * N * (N - 1));
431     }
432 
433     /**
434      * {@inheritDoc}
435      *
436      * <p>For population size \( N \), number of successes \( K \), and sample
437      * size \( n \), the lower bound of the support is \( \max \{ 0, n + K - N \} \).
438      *
439      * @return lower bound of the support
440      */
441     @Override
442     public int getSupportLowerBound() {
443         return lowerBound;
444     }
445 
446     /**
447      * {@inheritDoc}
448      *
449      * <p>For number of successes \( K \), and sample
450      * size \( n \), the upper bound of the support is \( \min \{ n, K \} \).
451      *
452      * @return upper bound of the support
453      */
454     @Override
455     public int getSupportUpperBound() {
456         return upperBound;
457     }
458 
459     /**
460      * Return the mid-point {@code x} of the distribution, and the cdf(x).
461      *
462      * <p>This is not the true median. It is the value where the CDF(x) is closest to 0.5;
463      * as such the CDF may be below 0.5 if the next value of x is further from 0.5.
464      *
465      * @return the mid-point ([x, cdf(x)])
466      */
467     private double[] getMidPoint() {
468         double[] v = midpoint;
469         if (v == null) {
470             // Find the closest sum(PDF) to 0.5
471             int x = lowerBound;
472             double p0 = 0;
473             double p1 = Math.exp(computeLogProbability(x));
474             // No check of the upper bound required here as the CDF should sum to 1 and 0.5
475             // is exceeded before a bounds error.
476             while (p1 < HALF) {
477                 x++;
478                 p0 = p1;
479                 p1 += Math.exp(computeLogProbability(x));
480             }
481             // p1 >= 0.5 > p0
482             // Pick closet
483             if (p1 - HALF >= HALF - p0) {
484                 x--;
485                 p1 = p0;
486             }
487             midpoint = v = new double[] {x, p1};
488         }
489         return v;
490     }
491 }