HypergeometricDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.statistics.distribution;

  18. import java.util.function.DoublePredicate;

  19. /**
  20.  * Implementation of the hypergeometric distribution.
  21.  *
  22.  * <p>The probability mass function of \( X \) is:
  23.  *
  24.  * <p>\[ f(k; N, K, n) = \frac{\binom{K}{k} \binom{N - K}{n-k}}{\binom{N}{n}} \]
  25.  *
  26.  * <p>for \( N \in \{0, 1, 2, \dots\} \) the population size,
  27.  * \( K \in \{0, 1, \dots, N\} \) the number of success states,
  28.  * \( n \in \{0, 1, \dots, N\} \) the number of samples,
  29.  * \( k \in \{\max(0, n+K-N), \dots, \min(n, K)\} \) the number of successes, and
  30.  *
  31.  * <p>\[ \binom{a}{b} = \frac{a!}{b! \, (a-b)!} \]
  32.  *
  33.  * <p>is the binomial coefficient.
  34.  *
  35.  * @see <a href="https://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
  36.  * @see <a href="https://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
  37.  */
  38. public final class HypergeometricDistribution extends AbstractDiscreteDistribution {
  39.     /** 1/2. */
  40.     private static final double HALF = 0.5;
  41.     /** The number of successes in the population. */
  42.     private final int numberOfSuccesses;
  43.     /** The population size. */
  44.     private final int populationSize;
  45.     /** The sample size. */
  46.     private final int sampleSize;
  47.     /** The lower bound of the support (inclusive). */
  48.     private final int lowerBound;
  49.     /** The upper bound of the support (inclusive). */
  50.     private final int upperBound;
  51.     /** Binomial probability of success (sampleSize / populationSize). */
  52.     private final double bp;
  53.     /** Binomial probability of failure ((populationSize - sampleSize) / populationSize). */
  54.     private final double bq;
  55.     /** Cached midpoint of the CDF/SF. The array holds [x, cdf(x)] for the midpoint x.
  56.      * Used for the cumulative probability functions. */
  57.     private double[] midpoint;

  58.     /**
  59.      * @param populationSize Population size.
  60.      * @param numberOfSuccesses Number of successes in the population.
  61.      * @param sampleSize Sample size.
  62.      */
  63.     private HypergeometricDistribution(int populationSize,
  64.                                        int numberOfSuccesses,
  65.                                        int sampleSize) {
  66.         this.numberOfSuccesses = numberOfSuccesses;
  67.         this.populationSize = populationSize;
  68.         this.sampleSize = sampleSize;
  69.         lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize);
  70.         upperBound = getUpperDomain(numberOfSuccesses, sampleSize);
  71.         bp = (double) sampleSize / populationSize;
  72.         bq = (double) (populationSize - sampleSize) / populationSize;
  73.     }

  74.     /**
  75.      * Creates a hypergeometric distribution.
  76.      *
  77.      * @param populationSize Population size.
  78.      * @param numberOfSuccesses Number of successes in the population.
  79.      * @param sampleSize Sample size.
  80.      * @return the distribution
  81.      * @throws IllegalArgumentException if {@code numberOfSuccesses < 0}, or
  82.      * {@code populationSize <= 0} or {@code numberOfSuccesses > populationSize}, or
  83.      * {@code sampleSize > populationSize}.
  84.      */
  85.     public static HypergeometricDistribution of(int populationSize,
  86.                                                 int numberOfSuccesses,
  87.                                                 int sampleSize) {
  88.         if (populationSize <= 0) {
  89.             throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE,
  90.                                             populationSize);
  91.         }
  92.         if (numberOfSuccesses < 0) {
  93.             throw new DistributionException(DistributionException.NEGATIVE,
  94.                                             numberOfSuccesses);
  95.         }
  96.         if (sampleSize < 0) {
  97.             throw new DistributionException(DistributionException.NEGATIVE,
  98.                                             sampleSize);
  99.         }

  100.         if (numberOfSuccesses > populationSize) {
  101.             throw new DistributionException(DistributionException.TOO_LARGE,
  102.                                             numberOfSuccesses, populationSize);
  103.         }
  104.         if (sampleSize > populationSize) {
  105.             throw new DistributionException(DistributionException.TOO_LARGE,
  106.                                             sampleSize, populationSize);
  107.         }
  108.         return new HypergeometricDistribution(populationSize, numberOfSuccesses, sampleSize);
  109.     }

  110.     /**
  111.      * Return the lowest domain value for the given hypergeometric distribution
  112.      * parameters.
  113.      *
  114.      * @param nn Population size.
  115.      * @param k Number of successes in the population.
  116.      * @param n Sample size.
  117.      * @return the lowest domain value of the hypergeometric distribution.
  118.      */
  119.     private static int getLowerDomain(int nn, int k, int n) {
  120.         // Avoid overflow given N > n:
  121.         // n + K - N == K - (N - n)
  122.         return Math.max(0, k - (nn - n));
  123.     }

  124.     /**
  125.      * Return the highest domain value for the given hypergeometric distribution
  126.      * parameters.
  127.      *
  128.      * @param k Number of successes in the population.
  129.      * @param n Sample size.
  130.      * @return the highest domain value of the hypergeometric distribution.
  131.      */
  132.     private static int getUpperDomain(int k, int n) {
  133.         return Math.min(n, k);
  134.     }

  135.     /**
  136.      * Gets the population size parameter of this distribution.
  137.      *
  138.      * @return the population size.
  139.      */
  140.     public int getPopulationSize() {
  141.         return populationSize;
  142.     }

  143.     /**
  144.      * Gets the number of successes parameter of this distribution.
  145.      *
  146.      * @return the number of successes.
  147.      */
  148.     public int getNumberOfSuccesses() {
  149.         return numberOfSuccesses;
  150.     }

  151.     /**
  152.      * Gets the sample size parameter of this distribution.
  153.      *
  154.      * @return the sample size.
  155.      */
  156.     public int getSampleSize() {
  157.         return sampleSize;
  158.     }

  159.     /** {@inheritDoc} */
  160.     @Override
  161.     public double probability(int x) {
  162.         return Math.exp(logProbability(x));
  163.     }

  164.     /** {@inheritDoc} */
  165.     @Override
  166.     public double probability(int x0, int x1) {
  167.         if (x0 > x1) {
  168.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
  169.         }
  170.         if (x0 == x1 || x1 < lowerBound) {
  171.             return 0;
  172.         }
  173.         // If the range is outside the bounds use the appropriate cumulative probability
  174.         if (x0 < lowerBound) {
  175.             return cumulativeProbability(x1);
  176.         }
  177.         if (x1 >= upperBound) {
  178.             // 1 - cdf(x0)
  179.             return survivalProbability(x0);
  180.         }
  181.         // Here: lower <= x0 < x1 < upper:
  182.         // sum(pdf(x)) for x in (x0, x1]
  183.         final int lo = x0 + 1;
  184.         // Sum small values first by starting at the point the greatest distance from the mode.
  185.         final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0));
  186.         return Math.abs(mode - lo) > Math.abs(mode - x1) ?
  187.             innerCumulativeProbability(lo, x1) :
  188.             innerCumulativeProbability(x1, lo);
  189.     }

  190.     /** {@inheritDoc} */
  191.     @Override
  192.     public double logProbability(int x) {
  193.         if (x < lowerBound || x > upperBound) {
  194.             return Double.NEGATIVE_INFINITY;
  195.         }
  196.         return computeLogProbability(x);
  197.     }

  198.     /**
  199.      * Compute the log probability.
  200.      *
  201.      * @param x Value.
  202.      * @return log(P(X = x))
  203.      */
  204.     private double computeLogProbability(int x) {
  205.         final double p1 =
  206.                 SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq);
  207.         final double p2 =
  208.                 SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
  209.                         populationSize - numberOfSuccesses, bp, bq);
  210.         final double p3 =
  211.                 SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq);
  212.         return p1 + p2 - p3;
  213.     }

  214.     /** {@inheritDoc} */
  215.     @Override
  216.     public double cumulativeProbability(int x) {
  217.         if (x < lowerBound) {
  218.             return 0.0;
  219.         } else if (x >= upperBound) {
  220.             return 1.0;
  221.         }
  222.         final double[] mid = getMidPoint();
  223.         final int m = (int) mid[0];
  224.         if (x < m) {
  225.             return innerCumulativeProbability(lowerBound, x);
  226.         } else if (x > m) {
  227.             return 1 - innerCumulativeProbability(upperBound, x + 1);
  228.         }
  229.         // cdf(x)
  230.         return mid[1];
  231.     }

  232.     /** {@inheritDoc} */
  233.     @Override
  234.     public double survivalProbability(int x) {
  235.         if (x < lowerBound) {
  236.             return 1.0;
  237.         } else if (x >= upperBound) {
  238.             return 0.0;
  239.         }
  240.         final double[] mid = getMidPoint();
  241.         final int m = (int) mid[0];
  242.         if (x < m) {
  243.             return 1 - innerCumulativeProbability(lowerBound, x);
  244.         } else if (x > m) {
  245.             return innerCumulativeProbability(upperBound, x + 1);
  246.         }
  247.         // 1 - cdf(x)
  248.         return 1 - mid[1];
  249.     }

  250.     /**
  251.      * For this distribution, {@code X}, this method returns
  252.      * {@code P(x0 <= X <= x1)}.
  253.      * This probability is computed by summing the point probabilities for the
  254.      * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
  255.      * using a comparison of the input bounds.
  256.      * This should be called by using {@code x0} as the domain limit and {@code x1}
  257.      * as the internal value. This will result in an initial sum of increasing larger magnitudes.
  258.      *
  259.      * @param x0 Inclusive domain bound.
  260.      * @param x1 Inclusive internal bound.
  261.      * @return {@code P(x0 <= X <= x1)}.
  262.      */
  263.     private double innerCumulativeProbability(int x0, int x1) {
  264.         // Assume the range is within the domain.
  265.         // Reuse the computation for probability(x) but avoid checking the domain for each call.
  266.         int x = x0;
  267.         double ret = Math.exp(computeLogProbability(x));
  268.         if (x0 < x1) {
  269.             while (x != x1) {
  270.                 x++;
  271.                 ret += Math.exp(computeLogProbability(x));
  272.             }
  273.         } else {
  274.             while (x != x1) {
  275.                 x--;
  276.                 ret += Math.exp(computeLogProbability(x));
  277.             }
  278.         }
  279.         return ret;
  280.     }

  281.     @Override
  282.     public int inverseCumulativeProbability(double p) {
  283.         ArgumentUtils.checkProbability(p);
  284.         return computeInverseProbability(p, 1 - p, false);
  285.     }

  286.     @Override
  287.     public int inverseSurvivalProbability(double p) {
  288.         ArgumentUtils.checkProbability(p);
  289.         return computeInverseProbability(1 - p, p, true);
  290.     }

  291.     /**
  292.      * Implementation for the inverse cumulative or survival probability.
  293.      *
  294.      * @param p Cumulative probability.
  295.      * @param q Survival probability.
  296.      * @param complement Set to true to compute the inverse survival probability.
  297.      * @return the value
  298.      */
  299.     private int computeInverseProbability(double p, double q, boolean complement) {
  300.         if (p == 0) {
  301.             return lowerBound;
  302.         }
  303.         if (q == 0) {
  304.             return upperBound;
  305.         }

  306.         // Sum the PDF(x) until the appropriate p-value is obtained
  307.         // CDF: require smallest x where P(X<=x) >= p
  308.         // SF:  require smallest x where P(X>x) <= q
  309.         // The choice of summation uses the mid-point.
  310.         // The test on the CDF or SF is based on the appropriate input p-value.

  311.         final double[] mid = getMidPoint();
  312.         final int m = (int) mid[0];
  313.         final double mp = mid[1];

  314.         final int midPointComparison = complement ?
  315.             Double.compare(1 - mp, q) :
  316.             Double.compare(p, mp);

  317.         if (midPointComparison < 0) {
  318.             return inverseLower(p, q, complement);
  319.         } else if (midPointComparison > 0) {
  320.             // Avoid floating-point summation error when the mid-point computed using the
  321.             // lower sum is different to the midpoint computed using the upper sum.
  322.             // Here we know the result must be above the midpoint so we can clip the result.
  323.             return Math.max(m + 1, inverseUpper(p, q, complement));
  324.         }
  325.         // Exact mid-point
  326.         return m;
  327.     }

  328.     /**
  329.      * Compute the inverse cumulative or survival probability using the lower sum.
  330.      *
  331.      * @param p Cumulative probability.
  332.      * @param q Survival probability.
  333.      * @param complement Set to true to compute the inverse survival probability.
  334.      * @return the value
  335.      */
  336.     private int inverseLower(double p, double q, boolean complement) {
  337.         // Sum from the lower bound (computing the cdf)
  338.         int x = lowerBound;
  339.         final DoublePredicate test = complement ?
  340.             i -> 1 - i > q :
  341.             i -> i < p;
  342.         double cdf = Math.exp(computeLogProbability(x));
  343.         while (test.test(cdf)) {
  344.             x++;
  345.             cdf += Math.exp(computeLogProbability(x));
  346.         }
  347.         return x;
  348.     }

  349.     /**
  350.      * Compute the inverse cumulative or survival probability using the upper sum.
  351.      *
  352.      * @param p Cumulative probability.
  353.      * @param q Survival probability.
  354.      * @param complement Set to true to compute the inverse survival probability.
  355.      * @return the value
  356.      */
  357.     private int inverseUpper(double p, double q, boolean complement) {
  358.         // Sum from the upper bound (computing the sf)
  359.         int x = upperBound;
  360.         final DoublePredicate test = complement ?
  361.             i -> i < q :
  362.             i -> 1 - i > p;
  363.         double sf = 0;
  364.         while (test.test(sf)) {
  365.             sf += Math.exp(computeLogProbability(x));
  366.             x--;
  367.         }
  368.         // Here either sf(x) >= q, or cdf(x) <= p
  369.         // Ensure sf(x) <= q, or cdf(x) >= p
  370.         if (complement && sf > q ||
  371.             !complement && 1 - sf < p) {
  372.             x++;
  373.         }
  374.         return x;
  375.     }

  376.     /**
  377.      * {@inheritDoc}
  378.      *
  379.      * <p>For population size \( N \), number of successes \( K \), and sample
  380.      * size \( n \), the mean is:
  381.      *
  382.      * <p>\[ n \frac{K}{N} \]
  383.      */
  384.     @Override
  385.     public double getMean() {
  386.         return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
  387.     }

  388.     /**
  389.      * {@inheritDoc}
  390.      *
  391.      * <p>For population size \( N \), number of successes \( K \), and sample
  392.      * size \( n \), the variance is:
  393.      *
  394.      * <p>\[ n \frac{K}{N} \frac{N-K}{N} \frac{N-n}{N-1} \]
  395.      */
  396.     @Override
  397.     public double getVariance() {
  398.         final double N = getPopulationSize();
  399.         final double K = getNumberOfSuccesses();
  400.         final double n = getSampleSize();
  401.         return (n * K * (N - K) * (N - n)) / (N * N * (N - 1));
  402.     }

  403.     /**
  404.      * {@inheritDoc}
  405.      *
  406.      * <p>For population size \( N \), number of successes \( K \), and sample
  407.      * size \( n \), the lower bound of the support is \( \max \{ 0, n + K - N \} \).
  408.      *
  409.      * @return lower bound of the support
  410.      */
  411.     @Override
  412.     public int getSupportLowerBound() {
  413.         return lowerBound;
  414.     }

  415.     /**
  416.      * {@inheritDoc}
  417.      *
  418.      * <p>For number of successes \( K \), and sample
  419.      * size \( n \), the upper bound of the support is \( \min \{ n, K \} \).
  420.      *
  421.      * @return upper bound of the support
  422.      */
  423.     @Override
  424.     public int getSupportUpperBound() {
  425.         return upperBound;
  426.     }

  427.     /**
  428.      * Return the mid-point {@code x} of the distribution, and the cdf(x).
  429.      *
  430.      * <p>This is not the true median. It is the value where the CDF(x) is closest to 0.5;
  431.      * as such the CDF may be below 0.5 if the next value of x is further from 0.5.
  432.      *
  433.      * @return the mid-point ([x, cdf(x)])
  434.      */
  435.     private double[] getMidPoint() {
  436.         double[] v = midpoint;
  437.         if (v == null) {
  438.             // Find the closest sum(PDF) to 0.5
  439.             int x = lowerBound;
  440.             double p0 = 0;
  441.             double p1 = Math.exp(computeLogProbability(x));
  442.             // No check of the upper bound required here as the CDF should sum to 1 and 0.5
  443.             // is exceeded before a bounds error.
  444.             while (p1 < HALF) {
  445.                 x++;
  446.                 p0 = p1;
  447.                 p1 += Math.exp(computeLogProbability(x));
  448.             }
  449.             // p1 >= 0.5 > p0
  450.             // Pick closet
  451.             if (p1 - HALF >= HALF - p0) {
  452.                 x--;
  453.                 p1 = p0;
  454.             }
  455.             midpoint = v = new double[] {x, p1};
  456.         }
  457.         return v;
  458.     }
  459. }