Hypergeom.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.inference;

  18. import org.apache.commons.statistics.distribution.HypergeometricDistribution;

  19. /**
  20.  * Provide a wrapper around the {@link HypergeometricDistribution} that caches
  21.  * all probability mass values.
  22.  *
  23.  * <p>This class extracts the logic from the HypergeometricDistribution implementation
  24.  * used for the cumulative probability functions. It allows fast computation of
  25.  * the CDF and SF for the entire supported domain.
  26.  *
  27.  * @since 1.1
  28.  */
  29. class Hypergeom {
  30.     /** 1/2. */
  31.     private static final double HALF = 0.5;
  32.     /** The lower bound of the support (inclusive). */
  33.     private final int lowerBound;
  34.     /** The upper bound of the support (inclusive). */
  35.     private final int upperBound;
  36.     /** Cached probability values. This holds values from x=0 even though the supported
  37.      * lower bound may be above x=0. This allows x to be used as an index without offsetting
  38.      * using the lower bound. */
  39.     private final double[] prob;
  40.     /** Cached midpoint, m, of the CDF/SF. This is not the true median. It is the value where
  41.      * the CDF is closest to 0.5; as such the CDF(m) may be below 0.5 if the next value
  42.      * CDF(m+1) is further from 0.5. Used for the cumulative probability functions. */
  43.     private final int m;
  44.     /** Cached CDF of the midpoint.
  45.      * Used for the cumulative probability functions. */
  46.     private final double midCDF;
  47.     /** Lower mode. */
  48.     private final int m1;
  49.     /** Upper mode. */
  50.     private final int m2;

  51.     /**
  52.      * @param populationSize Population size.
  53.      * @param numberOfSuccesses Number of successes in the population.
  54.      * @param sampleSize Sample size.
  55.      */
  56.     Hypergeom(int populationSize,
  57.               int numberOfSuccesses,
  58.               int sampleSize) {
  59.         final HypergeometricDistribution dist =
  60.             HypergeometricDistribution.of(populationSize, numberOfSuccesses, sampleSize);

  61.         // Cache all values required to compute the cumulative probability functions

  62.         // Bounds
  63.         lowerBound = dist.getSupportLowerBound();
  64.         upperBound = dist.getSupportUpperBound();

  65.         // PMF values
  66.         prob = new double[upperBound + 1];
  67.         for (int x = lowerBound; x <= upperBound; x++) {
  68.             prob[x] = dist.probability(x);
  69.         }

  70.         // Compute mid-point for CDF/SF computation
  71.         // Find the closest sum(PDF) to 0.5.
  72.         int x = lowerBound;
  73.         double p0 = 0;
  74.         double p1 = prob[x];
  75.         // No check of the upper bound required here as the CDF should sum to 1 and 0.5
  76.         // is exceeded before a bounds error.
  77.         while (p1 < HALF) {
  78.             x++;
  79.             p0 = p1;
  80.             p1 += prob[x];
  81.         }
  82.         // p1 >= 0.5 > p0
  83.         // Pick closet
  84.         if (p1 - HALF >= HALF - p0) {
  85.             x--;
  86.             p1 = p0;
  87.         }
  88.         m = x;
  89.         midCDF = p1;

  90.         // Compute the mode (lower != upper in the case where v is integer).
  91.         // This value is used by the UnconditionedExactTest and is cached here for convenience.
  92.         final double v = ((double) numberOfSuccesses + 1) * ((double) sampleSize + 1) / (populationSize + 2.0);
  93.         m1 = (int) Math.ceil(v) - 1;
  94.         m2 = (int) Math.floor(v);
  95.     }

  96.     /**
  97.      * Get the lower bound of the support.
  98.      *
  99.      * @return lower bound
  100.      */
  101.     int getSupportLowerBound() {
  102.         return lowerBound;
  103.     }

  104.     /**
  105.      * Get the upper bound of the support.
  106.      *
  107.      * @return upper bound
  108.      */
  109.     int getSupportUpperBound() {
  110.         return upperBound;
  111.     }

  112.     /**
  113.      * Get the lower mode of the distribution.
  114.      *
  115.      * @return lower mode
  116.      */
  117.     int getLowerMode() {
  118.         return m1;
  119.     }

  120.     /**
  121.      * Get the upper mode of the distribution.
  122.      *
  123.      * @return upper mode
  124.      */
  125.     int getUpperMode() {
  126.         return m2;
  127.     }

  128.     /**
  129.      * Compute the probability mass function (PMF) at the specified value.
  130.      *
  131.      * @param x Value.
  132.      * @return P(X = x)
  133.      * @throws IndexOutOfBoundsException if the value {@code x} is not in the supported domain.
  134.      */
  135.     double pmf(int x) {
  136.         return prob[x];
  137.     }

  138.     /**
  139.      * Compute the cumulative distribution function (CDF) at the specified value.
  140.      *
  141.      * @param x Value.
  142.      * @return P(X <= x)
  143.      */
  144.     double cdf(int x) {
  145.         if (x < lowerBound) {
  146.             return 0.0;
  147.         } else if (x >= upperBound) {
  148.             return 1.0;
  149.         }
  150.         if (x < m) {
  151.             return innerCumulativeProbability(lowerBound, x);
  152.         } else if (x > m) {
  153.             return 1 - innerCumulativeProbability(upperBound, x + 1);
  154.         }
  155.         // cdf(x)
  156.         return midCDF;
  157.     }

  158.     /**
  159.      * Compute the survival function (SF) at the specified value. This is the complementary
  160.      * cumulative distribution function.
  161.      *
  162.      * @param x Value.
  163.      * @return P(X > x)
  164.      */
  165.     double sf(int x) {
  166.         if (x < lowerBound) {
  167.             return 1.0;
  168.         } else if (x >= upperBound) {
  169.             return 0.0;
  170.         }
  171.         if (x < m) {
  172.             return 1 - innerCumulativeProbability(lowerBound, x);
  173.         } else if (x > m) {
  174.             return innerCumulativeProbability(upperBound, x + 1);
  175.         }
  176.         // 1 - cdf(x)
  177.         return 1 - midCDF;
  178.     }

  179.     /**
  180.      * For this distribution, {@code X}, this method returns
  181.      * {@code P(x0 <= X <= x1)}.
  182.      * This probability is computed by summing the point probabilities for the
  183.      * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
  184.      * using a comparison of the input bounds.
  185.      * This should be called by using {@code x0} as the domain limit and {@code x1}
  186.      * as the internal value. This will result in a sum of increasingly larger magnitudes.
  187.      *
  188.      * @param x0 Inclusive domain bound.
  189.      * @param x1 Inclusive internal bound.
  190.      * @return {@code P(x0 <= X <= x1)}.
  191.      */
  192.     private double innerCumulativeProbability(int x0, int x1) {
  193.         // Assume the range is within the domain.
  194.         int x = x0;
  195.         double ret = prob[x];
  196.         if (x0 < x1) {
  197.             while (x != x1) {
  198.                 x++;
  199.                 ret += prob[x];
  200.             }
  201.         } else {
  202.             while (x != x1) {
  203.                 x--;
  204.                 ret += prob[x];
  205.             }
  206.         }
  207.         return ret;
  208.     }
  209. }