Hypergeom.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.statistics.inference;
- import org.apache.commons.statistics.distribution.HypergeometricDistribution;
- /**
- * Provide a wrapper around the {@link HypergeometricDistribution} that caches
- * all probability mass values.
- *
- * <p>This class extracts the logic from the HypergeometricDistribution implementation
- * used for the cumulative probability functions. It allows fast computation of
- * the CDF and SF for the entire supported domain.
- *
- * @since 1.1
- */
- class Hypergeom {
- /** 1/2. */
- private static final double HALF = 0.5;
- /** The lower bound of the support (inclusive). */
- private final int lowerBound;
- /** The upper bound of the support (inclusive). */
- private final int upperBound;
- /** Cached probability values. This holds values from x=0 even though the supported
- * lower bound may be above x=0. This allows x to be used as an index without offsetting
- * using the lower bound. */
- private final double[] prob;
- /** Cached midpoint, m, of the CDF/SF. This is not the true median. It is the value where
- * the CDF is closest to 0.5; as such the CDF(m) may be below 0.5 if the next value
- * CDF(m+1) is further from 0.5. Used for the cumulative probability functions. */
- private final int m;
- /** Cached CDF of the midpoint.
- * Used for the cumulative probability functions. */
- private final double midCDF;
- /** Lower mode. */
- private final int m1;
- /** Upper mode. */
- private final int m2;
- /**
- * @param populationSize Population size.
- * @param numberOfSuccesses Number of successes in the population.
- * @param sampleSize Sample size.
- */
- Hypergeom(int populationSize,
- int numberOfSuccesses,
- int sampleSize) {
- final HypergeometricDistribution dist =
- HypergeometricDistribution.of(populationSize, numberOfSuccesses, sampleSize);
- // Cache all values required to compute the cumulative probability functions
- // Bounds
- lowerBound = dist.getSupportLowerBound();
- upperBound = dist.getSupportUpperBound();
- // PMF values
- prob = new double[upperBound + 1];
- for (int x = lowerBound; x <= upperBound; x++) {
- prob[x] = dist.probability(x);
- }
- // Compute mid-point for CDF/SF computation
- // Find the closest sum(PDF) to 0.5.
- int x = lowerBound;
- double p0 = 0;
- double p1 = prob[x];
- // No check of the upper bound required here as the CDF should sum to 1 and 0.5
- // is exceeded before a bounds error.
- while (p1 < HALF) {
- x++;
- p0 = p1;
- p1 += prob[x];
- }
- // p1 >= 0.5 > p0
- // Pick closet
- if (p1 - HALF >= HALF - p0) {
- x--;
- p1 = p0;
- }
- m = x;
- midCDF = p1;
- // Compute the mode (lower != upper in the case where v is integer).
- // This value is used by the UnconditionedExactTest and is cached here for convenience.
- final double v = ((double) numberOfSuccesses + 1) * ((double) sampleSize + 1) / (populationSize + 2.0);
- m1 = (int) Math.ceil(v) - 1;
- m2 = (int) Math.floor(v);
- }
- /**
- * Get the lower bound of the support.
- *
- * @return lower bound
- */
- int getSupportLowerBound() {
- return lowerBound;
- }
- /**
- * Get the upper bound of the support.
- *
- * @return upper bound
- */
- int getSupportUpperBound() {
- return upperBound;
- }
- /**
- * Get the lower mode of the distribution.
- *
- * @return lower mode
- */
- int getLowerMode() {
- return m1;
- }
- /**
- * Get the upper mode of the distribution.
- *
- * @return upper mode
- */
- int getUpperMode() {
- return m2;
- }
- /**
- * Compute the probability mass function (PMF) at the specified value.
- *
- * @param x Value.
- * @return P(X = x)
- * @throws IndexOutOfBoundsException if the value {@code x} is not in the supported domain.
- */
- double pmf(int x) {
- return prob[x];
- }
- /**
- * Compute the cumulative distribution function (CDF) at the specified value.
- *
- * @param x Value.
- * @return P(X <= x)
- */
- double cdf(int x) {
- if (x < lowerBound) {
- return 0.0;
- } else if (x >= upperBound) {
- return 1.0;
- }
- if (x < m) {
- return innerCumulativeProbability(lowerBound, x);
- } else if (x > m) {
- return 1 - innerCumulativeProbability(upperBound, x + 1);
- }
- // cdf(x)
- return midCDF;
- }
- /**
- * Compute the survival function (SF) at the specified value. This is the complementary
- * cumulative distribution function.
- *
- * @param x Value.
- * @return P(X > x)
- */
- double sf(int x) {
- if (x < lowerBound) {
- return 1.0;
- } else if (x >= upperBound) {
- return 0.0;
- }
- if (x < m) {
- return 1 - innerCumulativeProbability(lowerBound, x);
- } else if (x > m) {
- return innerCumulativeProbability(upperBound, x + 1);
- }
- // 1 - cdf(x)
- return 1 - midCDF;
- }
- /**
- * For this distribution, {@code X}, this method returns
- * {@code P(x0 <= X <= x1)}.
- * This probability is computed by summing the point probabilities for the
- * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined
- * using a comparison of the input bounds.
- * This should be called by using {@code x0} as the domain limit and {@code x1}
- * as the internal value. This will result in a sum of increasingly larger magnitudes.
- *
- * @param x0 Inclusive domain bound.
- * @param x1 Inclusive internal bound.
- * @return {@code P(x0 <= X <= x1)}.
- */
- private double innerCumulativeProbability(int x0, int x1) {
- // Assume the range is within the domain.
- int x = x0;
- double ret = prob[x];
- if (x0 < x1) {
- while (x != x1) {
- x++;
- ret += prob[x];
- }
- } else {
- while (x != x1) {
- x--;
- ret += prob[x];
- }
- }
- return ret;
- }
- }