001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.distribution;
019
020import org.apache.commons.math3.exception.NotPositiveException;
021import org.apache.commons.math3.exception.NotStrictlyPositiveException;
022import org.apache.commons.math3.exception.NumberIsTooLargeException;
023import org.apache.commons.math3.exception.util.LocalizedFormats;
024import org.apache.commons.math3.util.FastMath;
025import org.apache.commons.math3.random.RandomGenerator;
026import org.apache.commons.math3.random.Well19937c;
027
028/**
029 * Implementation of the hypergeometric distribution.
030 *
031 * @see <a href="http://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
032 * @see <a href="http://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
033 * @version $Id: HypergeometricDistribution.java 1534358 2013-10-21 20:13:52Z tn $
034 */
035public class HypergeometricDistribution extends AbstractIntegerDistribution {
036    /** Serializable version identifier. */
037    private static final long serialVersionUID = -436928820673516179L;
038    /** The number of successes in the population. */
039    private final int numberOfSuccesses;
040    /** The population size. */
041    private final int populationSize;
042    /** The sample size. */
043    private final int sampleSize;
044    /** Cached numerical variance */
045    private double numericalVariance = Double.NaN;
046    /** Whether or not the numerical variance has been calculated */
047    private boolean numericalVarianceIsCalculated = false;
048
049    /**
050     * Construct a new hypergeometric distribution with the specified population
051     * size, number of successes in the population, and sample size.
052     *
053     * @param populationSize Population size.
054     * @param numberOfSuccesses Number of successes in the population.
055     * @param sampleSize Sample size.
056     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
057     * @throws NotStrictlyPositiveException if {@code populationSize <= 0}.
058     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > populationSize},
059     * or {@code sampleSize > populationSize}.
060     */
061    public HypergeometricDistribution(int populationSize, int numberOfSuccesses, int sampleSize)
062    throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
063        this(new Well19937c(), populationSize, numberOfSuccesses, sampleSize);
064    }
065
066    /**
067     * Creates a new hypergeometric distribution.
068     *
069     * @param rng Random number generator.
070     * @param populationSize Population size.
071     * @param numberOfSuccesses Number of successes in the population.
072     * @param sampleSize Sample size.
073     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
074     * @throws NotStrictlyPositiveException if {@code populationSize <= 0}.
075     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > populationSize},
076     * or {@code sampleSize > populationSize}.
077     * @since 3.1
078     */
079    public HypergeometricDistribution(RandomGenerator rng,
080                                      int populationSize,
081                                      int numberOfSuccesses,
082                                      int sampleSize)
083    throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
084        super(rng);
085
086        if (populationSize <= 0) {
087            throw new NotStrictlyPositiveException(LocalizedFormats.POPULATION_SIZE,
088                                                   populationSize);
089        }
090        if (numberOfSuccesses < 0) {
091            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_SUCCESSES,
092                                           numberOfSuccesses);
093        }
094        if (sampleSize < 0) {
095            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
096                                           sampleSize);
097        }
098
099        if (numberOfSuccesses > populationSize) {
100            throw new NumberIsTooLargeException(LocalizedFormats.NUMBER_OF_SUCCESS_LARGER_THAN_POPULATION_SIZE,
101                                                numberOfSuccesses, populationSize, true);
102        }
103        if (sampleSize > populationSize) {
104            throw new NumberIsTooLargeException(LocalizedFormats.SAMPLE_SIZE_LARGER_THAN_POPULATION_SIZE,
105                                                sampleSize, populationSize, true);
106        }
107
108        this.numberOfSuccesses = numberOfSuccesses;
109        this.populationSize = populationSize;
110        this.sampleSize = sampleSize;
111    }
112
113    /** {@inheritDoc} */
114    public double cumulativeProbability(int x) {
115        double ret;
116
117        int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
118        if (x < domain[0]) {
119            ret = 0.0;
120        } else if (x >= domain[1]) {
121            ret = 1.0;
122        } else {
123            ret = innerCumulativeProbability(domain[0], x, 1);
124        }
125
126        return ret;
127    }
128
129    /**
130     * Return the domain for the given hypergeometric distribution parameters.
131     *
132     * @param n Population size.
133     * @param m Number of successes in the population.
134     * @param k Sample size.
135     * @return a two element array containing the lower and upper bounds of the
136     * hypergeometric distribution.
137     */
138    private int[] getDomain(int n, int m, int k) {
139        return new int[] { getLowerDomain(n, m, k), getUpperDomain(m, k) };
140    }
141
142    /**
143     * Return the lowest domain value for the given hypergeometric distribution
144     * parameters.
145     *
146     * @param n Population size.
147     * @param m Number of successes in the population.
148     * @param k Sample size.
149     * @return the lowest domain value of the hypergeometric distribution.
150     */
151    private int getLowerDomain(int n, int m, int k) {
152        return FastMath.max(0, m - (n - k));
153    }
154
155    /**
156     * Access the number of successes.
157     *
158     * @return the number of successes.
159     */
160    public int getNumberOfSuccesses() {
161        return numberOfSuccesses;
162    }
163
164    /**
165     * Access the population size.
166     *
167     * @return the population size.
168     */
169    public int getPopulationSize() {
170        return populationSize;
171    }
172
173    /**
174     * Access the sample size.
175     *
176     * @return the sample size.
177     */
178    public int getSampleSize() {
179        return sampleSize;
180    }
181
182    /**
183     * Return the highest domain value for the given hypergeometric distribution
184     * parameters.
185     *
186     * @param m Number of successes in the population.
187     * @param k Sample size.
188     * @return the highest domain value of the hypergeometric distribution.
189     */
190    private int getUpperDomain(int m, int k) {
191        return FastMath.min(k, m);
192    }
193
194    /** {@inheritDoc} */
195    public double probability(int x) {
196        final double logProbability = logProbability(x);
197        return logProbability == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logProbability);
198    }
199
200    /** {@inheritDoc} */
201    @Override
202    public double logProbability(int x) {
203        double ret;
204
205        int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
206        if (x < domain[0] || x > domain[1]) {
207            ret = Double.NEGATIVE_INFINITY;
208        } else {
209            double p = (double) sampleSize / (double) populationSize;
210            double q = (double) (populationSize - sampleSize) / (double) populationSize;
211            double p1 = SaddlePointExpansion.logBinomialProbability(x,
212                    numberOfSuccesses, p, q);
213            double p2 =
214                    SaddlePointExpansion.logBinomialProbability(sampleSize - x,
215                            populationSize - numberOfSuccesses, p, q);
216            double p3 =
217                    SaddlePointExpansion.logBinomialProbability(sampleSize, populationSize, p, q);
218            ret = p1 + p2 - p3;
219        }
220
221        return ret;
222    }
223
224    /**
225     * For this distribution, {@code X}, this method returns {@code P(X >= x)}.
226     *
227     * @param x Value at which the CDF is evaluated.
228     * @return the upper tail CDF for this distribution.
229     * @since 1.1
230     */
231    public double upperCumulativeProbability(int x) {
232        double ret;
233
234        final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
235        if (x <= domain[0]) {
236            ret = 1.0;
237        } else if (x > domain[1]) {
238            ret = 0.0;
239        } else {
240            ret = innerCumulativeProbability(domain[1], x, -1);
241        }
242
243        return ret;
244    }
245
246    /**
247     * For this distribution, {@code X}, this method returns
248     * {@code P(x0 <= X <= x1)}.
249     * This probability is computed by summing the point probabilities for the
250     * values {@code x0, x0 + 1, x0 + 2, ..., x1}, in the order directed by
251     * {@code dx}.
252     *
253     * @param x0 Inclusive lower bound.
254     * @param x1 Inclusive upper bound.
255     * @param dx Direction of summation (1 indicates summing from x0 to x1, and
256     * 0 indicates summing from x1 to x0).
257     * @return {@code P(x0 <= X <= x1)}.
258     */
259    private double innerCumulativeProbability(int x0, int x1, int dx) {
260        double ret = probability(x0);
261        while (x0 != x1) {
262            x0 += dx;
263            ret += probability(x0);
264        }
265        return ret;
266    }
267
268    /**
269     * {@inheritDoc}
270     *
271     * For population size {@code N}, number of successes {@code m}, and sample
272     * size {@code n}, the mean is {@code n * m / N}.
273     */
274    public double getNumericalMean() {
275        return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
276    }
277
278    /**
279     * {@inheritDoc}
280     *
281     * For population size {@code N}, number of successes {@code m}, and sample
282     * size {@code n}, the variance is
283     * {@code [n * m * (N - n) * (N - m)] / [N^2 * (N - 1)]}.
284     */
285    public double getNumericalVariance() {
286        if (!numericalVarianceIsCalculated) {
287            numericalVariance = calculateNumericalVariance();
288            numericalVarianceIsCalculated = true;
289        }
290        return numericalVariance;
291    }
292
293    /**
294     * Used by {@link #getNumericalVariance()}.
295     *
296     * @return the variance of this distribution
297     */
298    protected double calculateNumericalVariance() {
299        final double N = getPopulationSize();
300        final double m = getNumberOfSuccesses();
301        final double n = getSampleSize();
302        return (n * m * (N - n) * (N - m)) / (N * N * (N - 1));
303    }
304
305    /**
306     * {@inheritDoc}
307     *
308     * For population size {@code N}, number of successes {@code m}, and sample
309     * size {@code n}, the lower bound of the support is
310     * {@code max(0, n + m - N)}.
311     *
312     * @return lower bound of the support
313     */
314    public int getSupportLowerBound() {
315        return FastMath.max(0,
316                            getSampleSize() + getNumberOfSuccesses() - getPopulationSize());
317    }
318
319    /**
320     * {@inheritDoc}
321     *
322     * For number of successes {@code m} and sample size {@code n}, the upper
323     * bound of the support is {@code min(m, n)}.
324     *
325     * @return upper bound of the support
326     */
327    public int getSupportUpperBound() {
328        return FastMath.min(getNumberOfSuccesses(), getSampleSize());
329    }
330
331    /**
332     * {@inheritDoc}
333     *
334     * The support of this distribution is connected.
335     *
336     * @return {@code true}
337     */
338    public boolean isSupportConnected() {
339        return true;
340    }
341}