001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.distribution;
019
020import org.apache.commons.math3.exception.NotPositiveException;
021import org.apache.commons.math3.exception.NotStrictlyPositiveException;
022import org.apache.commons.math3.exception.NumberIsTooLargeException;
023import org.apache.commons.math3.exception.util.LocalizedFormats;
024import org.apache.commons.math3.random.RandomGenerator;
025import org.apache.commons.math3.random.Well19937c;
026import org.apache.commons.math3.util.FastMath;
027
028/**
029 * Implementation of the hypergeometric distribution.
030 *
031 * @see <a href="http://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
032 * @see <a href="http://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
033 */
034public class HypergeometricDistribution extends AbstractIntegerDistribution {
035    /** Serializable version identifier. */
036    private static final long serialVersionUID = -436928820673516179L;
037    /** The number of successes in the population. */
038    private final int numberOfSuccesses;
039    /** The population size. */
040    private final int populationSize;
041    /** The sample size. */
042    private final int sampleSize;
043    /** Cached numerical variance */
044    private double numericalVariance = Double.NaN;
045    /** Whether or not the numerical variance has been calculated */
046    private boolean numericalVarianceIsCalculated = false;
047
048    /**
049     * Construct a new hypergeometric distribution with the specified population
050     * size, number of successes in the population, and sample size.
051     * <p>
052     * <b>Note:</b> this constructor will implicitly create an instance of
053     * {@link Well19937c} as random generator to be used for sampling only (see
054     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
055     * needed for the created distribution, it is advised to pass {@code null}
056     * as random generator via the appropriate constructors to avoid the
057     * additional initialisation overhead.
058     *
059     * @param populationSize Population size.
060     * @param numberOfSuccesses Number of successes in the population.
061     * @param sampleSize Sample size.
062     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
063     * @throws NotStrictlyPositiveException if {@code populationSize <= 0}.
064     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > populationSize},
065     * or {@code sampleSize > populationSize}.
066     */
067    public HypergeometricDistribution(int populationSize, int numberOfSuccesses, int sampleSize)
068    throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
069        this(new Well19937c(), populationSize, numberOfSuccesses, sampleSize);
070    }
071
072    /**
073     * Creates a new hypergeometric distribution.
074     *
075     * @param rng Random number generator.
076     * @param populationSize Population size.
077     * @param numberOfSuccesses Number of successes in the population.
078     * @param sampleSize Sample size.
079     * @throws NotPositiveException if {@code numberOfSuccesses < 0}.
080     * @throws NotStrictlyPositiveException if {@code populationSize <= 0}.
081     * @throws NumberIsTooLargeException if {@code numberOfSuccesses > populationSize},
082     * or {@code sampleSize > populationSize}.
083     * @since 3.1
084     */
085    public HypergeometricDistribution(RandomGenerator rng,
086                                      int populationSize,
087                                      int numberOfSuccesses,
088                                      int sampleSize)
089    throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException {
090        super(rng);
091
092        if (populationSize <= 0) {
093            throw new NotStrictlyPositiveException(LocalizedFormats.POPULATION_SIZE,
094                                                   populationSize);
095        }
096        if (numberOfSuccesses < 0) {
097            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_SUCCESSES,
098                                           numberOfSuccesses);
099        }
100        if (sampleSize < 0) {
101            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
102                                           sampleSize);
103        }
104
105        if (numberOfSuccesses > populationSize) {
106            throw new NumberIsTooLargeException(LocalizedFormats.NUMBER_OF_SUCCESS_LARGER_THAN_POPULATION_SIZE,
107                                                numberOfSuccesses, populationSize, true);
108        }
109        if (sampleSize > populationSize) {
110            throw new NumberIsTooLargeException(LocalizedFormats.SAMPLE_SIZE_LARGER_THAN_POPULATION_SIZE,
111                                                sampleSize, populationSize, true);
112        }
113
114        this.numberOfSuccesses = numberOfSuccesses;
115        this.populationSize = populationSize;
116        this.sampleSize = sampleSize;
117    }
118
119    /** {@inheritDoc} */
120    public double cumulativeProbability(int x) {
121        double ret;
122
123        int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
124        if (x < domain[0]) {
125            ret = 0.0;
126        } else if (x >= domain[1]) {
127            ret = 1.0;
128        } else {
129            ret = innerCumulativeProbability(domain[0], x, 1);
130        }
131
132        return ret;
133    }
134
135    /**
136     * Return the domain for the given hypergeometric distribution parameters.
137     *
138     * @param n Population size.
139     * @param m Number of successes in the population.
140     * @param k Sample size.
141     * @return a two element array containing the lower and upper bounds of the
142     * hypergeometric distribution.
143     */
144    private int[] getDomain(int n, int m, int k) {
145        return new int[] { getLowerDomain(n, m, k), getUpperDomain(m, k) };
146    }
147
148    /**
149     * Return the lowest domain value for the given hypergeometric distribution
150     * parameters.
151     *
152     * @param n Population size.
153     * @param m Number of successes in the population.
154     * @param k Sample size.
155     * @return the lowest domain value of the hypergeometric distribution.
156     */
157    private int getLowerDomain(int n, int m, int k) {
158        return FastMath.max(0, m - (n - k));
159    }
160
161    /**
162     * Access the number of successes.
163     *
164     * @return the number of successes.
165     */
166    public int getNumberOfSuccesses() {
167        return numberOfSuccesses;
168    }
169
170    /**
171     * Access the population size.
172     *
173     * @return the population size.
174     */
175    public int getPopulationSize() {
176        return populationSize;
177    }
178
179    /**
180     * Access the sample size.
181     *
182     * @return the sample size.
183     */
184    public int getSampleSize() {
185        return sampleSize;
186    }
187
188    /**
189     * Return the highest domain value for the given hypergeometric distribution
190     * parameters.
191     *
192     * @param m Number of successes in the population.
193     * @param k Sample size.
194     * @return the highest domain value of the hypergeometric distribution.
195     */
196    private int getUpperDomain(int m, int k) {
197        return FastMath.min(k, m);
198    }
199
200    /** {@inheritDoc} */
201    public double probability(int x) {
202        final double logProbability = logProbability(x);
203        return logProbability == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logProbability);
204    }
205
206    /** {@inheritDoc} */
207    @Override
208    public double logProbability(int x) {
209        double ret;
210
211        int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
212        if (x < domain[0] || x > domain[1]) {
213            ret = Double.NEGATIVE_INFINITY;
214        } else {
215            double p = (double) sampleSize / (double) populationSize;
216            double q = (double) (populationSize - sampleSize) / (double) populationSize;
217            double p1 = SaddlePointExpansion.logBinomialProbability(x,
218                    numberOfSuccesses, p, q);
219            double p2 =
220                    SaddlePointExpansion.logBinomialProbability(sampleSize - x,
221                            populationSize - numberOfSuccesses, p, q);
222            double p3 =
223                    SaddlePointExpansion.logBinomialProbability(sampleSize, populationSize, p, q);
224            ret = p1 + p2 - p3;
225        }
226
227        return ret;
228    }
229
230    /**
231     * For this distribution, {@code X}, this method returns {@code P(X >= x)}.
232     *
233     * @param x Value at which the CDF is evaluated.
234     * @return the upper tail CDF for this distribution.
235     * @since 1.1
236     */
237    public double upperCumulativeProbability(int x) {
238        double ret;
239
240        final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
241        if (x <= domain[0]) {
242            ret = 1.0;
243        } else if (x > domain[1]) {
244            ret = 0.0;
245        } else {
246            ret = innerCumulativeProbability(domain[1], x, -1);
247        }
248
249        return ret;
250    }
251
252    /**
253     * For this distribution, {@code X}, this method returns
254     * {@code P(x0 <= X <= x1)}.
255     * This probability is computed by summing the point probabilities for the
256     * values {@code x0, x0 + 1, x0 + 2, ..., x1}, in the order directed by
257     * {@code dx}.
258     *
259     * @param x0 Inclusive lower bound.
260     * @param x1 Inclusive upper bound.
261     * @param dx Direction of summation (1 indicates summing from x0 to x1, and
262     * 0 indicates summing from x1 to x0).
263     * @return {@code P(x0 <= X <= x1)}.
264     */
265    private double innerCumulativeProbability(int x0, int x1, int dx) {
266        double ret = probability(x0);
267        while (x0 != x1) {
268            x0 += dx;
269            ret += probability(x0);
270        }
271        return ret;
272    }
273
274    /**
275     * {@inheritDoc}
276     *
277     * For population size {@code N}, number of successes {@code m}, and sample
278     * size {@code n}, the mean is {@code n * m / N}.
279     */
280    public double getNumericalMean() {
281        return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
282    }
283
284    /**
285     * {@inheritDoc}
286     *
287     * For population size {@code N}, number of successes {@code m}, and sample
288     * size {@code n}, the variance is
289     * {@code [n * m * (N - n) * (N - m)] / [N^2 * (N - 1)]}.
290     */
291    public double getNumericalVariance() {
292        if (!numericalVarianceIsCalculated) {
293            numericalVariance = calculateNumericalVariance();
294            numericalVarianceIsCalculated = true;
295        }
296        return numericalVariance;
297    }
298
299    /**
300     * Used by {@link #getNumericalVariance()}.
301     *
302     * @return the variance of this distribution
303     */
304    protected double calculateNumericalVariance() {
305        final double N = getPopulationSize();
306        final double m = getNumberOfSuccesses();
307        final double n = getSampleSize();
308        return (n * m * (N - n) * (N - m)) / (N * N * (N - 1));
309    }
310
311    /**
312     * {@inheritDoc}
313     *
314     * For population size {@code N}, number of successes {@code m}, and sample
315     * size {@code n}, the lower bound of the support is
316     * {@code max(0, n + m - N)}.
317     *
318     * @return lower bound of the support
319     */
320    public int getSupportLowerBound() {
321        return FastMath.max(0,
322                            getSampleSize() + getNumberOfSuccesses() - getPopulationSize());
323    }
324
325    /**
326     * {@inheritDoc}
327     *
328     * For number of successes {@code m} and sample size {@code n}, the upper
329     * bound of the support is {@code min(m, n)}.
330     *
331     * @return upper bound of the support
332     */
333    public int getSupportUpperBound() {
334        return FastMath.min(getNumberOfSuccesses(), getSampleSize());
335    }
336
337    /**
338     * {@inheritDoc}
339     *
340     * The support of this distribution is connected.
341     *
342     * @return {@code true}
343     */
344    public boolean isSupportConnected() {
345        return true;
346    }
347}