View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math3.distribution;
18  
19  import org.apache.commons.math3.exception.NotStrictlyPositiveException;
20  import org.apache.commons.math3.exception.util.LocalizedFormats;
21  import org.apache.commons.math3.special.Gamma;
22  import org.apache.commons.math3.util.CombinatoricsUtils;
23  import org.apache.commons.math3.util.MathUtils;
24  import org.apache.commons.math3.util.FastMath;
25  import org.apache.commons.math3.random.RandomGenerator;
26  import org.apache.commons.math3.random.Well19937c;
27  
28  /**
29   * Implementation of the Poisson distribution.
30   *
31   * @see <a href="http://en.wikipedia.org/wiki/Poisson_distribution">Poisson distribution (Wikipedia)</a>
32   * @see <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution (MathWorld)</a>
33   */
34  public class PoissonDistribution extends AbstractIntegerDistribution {
35      /**
36       * Default maximum number of iterations for cumulative probability calculations.
37       * @since 2.1
38       */
39      public static final int DEFAULT_MAX_ITERATIONS = 10000000;
40      /**
41       * Default convergence criterion.
42       * @since 2.1
43       */
44      public static final double DEFAULT_EPSILON = 1e-12;
45      /** Serializable version identifier. */
46      private static final long serialVersionUID = -3349935121172596109L;
47      /** Distribution used to compute normal approximation. */
48      private final NormalDistribution normal;
49      /** Distribution needed for the {@link #sample()} method. */
50      private final ExponentialDistribution exponential;
51      /** Mean of the distribution. */
52      private final double mean;
53  
54      /**
55       * Maximum number of iterations for cumulative probability. Cumulative
56       * probabilities are estimated using either Lanczos series approximation
57       * of {@link Gamma#regularizedGammaP(double, double, double, int)}
58       * or continued fraction approximation of
59       * {@link Gamma#regularizedGammaQ(double, double, double, int)}.
60       */
61      private final int maxIterations;
62  
63      /** Convergence criterion for cumulative probability. */
64      private final double epsilon;
65  
66      /**
67       * Creates a new Poisson distribution with specified mean.
68       *
69       * @param p the Poisson mean
70       * @throws NotStrictlyPositiveException if {@code p <= 0}.
71       */
72      public PoissonDistribution(double p) throws NotStrictlyPositiveException {
73          this(p, DEFAULT_EPSILON, DEFAULT_MAX_ITERATIONS);
74      }
75  
76      /**
77       * Creates a new Poisson distribution with specified mean, convergence
78       * criterion and maximum number of iterations.
79       *
80       * @param p Poisson mean.
81       * @param epsilon Convergence criterion for cumulative probabilities.
82       * @param maxIterations the maximum number of iterations for cumulative
83       * probabilities.
84       * @throws NotStrictlyPositiveException if {@code p <= 0}.
85       * @since 2.1
86       */
87      public PoissonDistribution(double p, double epsilon, int maxIterations)
88      throws NotStrictlyPositiveException {
89          this(new Well19937c(), p, epsilon, maxIterations);
90      }
91  
92      /**
93       * Creates a new Poisson distribution with specified mean, convergence
94       * criterion and maximum number of iterations.
95       *
96       * @param rng Random number generator.
97       * @param p Poisson mean.
98       * @param epsilon Convergence criterion for cumulative probabilities.
99       * @param maxIterations the maximum number of iterations for cumulative
100      * probabilities.
101      * @throws NotStrictlyPositiveException if {@code p <= 0}.
102      * @since 3.1
103      */
104     public PoissonDistribution(RandomGenerator rng,
105                                double p,
106                                double epsilon,
107                                int maxIterations)
108     throws NotStrictlyPositiveException {
109         super(rng);
110 
111         if (p <= 0) {
112             throw new NotStrictlyPositiveException(LocalizedFormats.MEAN, p);
113         }
114         mean = p;
115         this.epsilon = epsilon;
116         this.maxIterations = maxIterations;
117 
118         // Use the same RNG instance as the parent class.
119         normal = new NormalDistribution(rng, p, FastMath.sqrt(p),
120                                         NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
121         exponential = new ExponentialDistribution(rng, 1,
122                                                   ExponentialDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
123     }
124 
125     /**
126      * Creates a new Poisson distribution with the specified mean and
127      * convergence criterion.
128      *
129      * @param p Poisson mean.
130      * @param epsilon Convergence criterion for cumulative probabilities.
131      * @throws NotStrictlyPositiveException if {@code p <= 0}.
132      * @since 2.1
133      */
134     public PoissonDistribution(double p, double epsilon)
135     throws NotStrictlyPositiveException {
136         this(p, epsilon, DEFAULT_MAX_ITERATIONS);
137     }
138 
139     /**
140      * Creates a new Poisson distribution with the specified mean and maximum
141      * number of iterations.
142      *
143      * @param p Poisson mean.
144      * @param maxIterations Maximum number of iterations for cumulative
145      * probabilities.
146      * @since 2.1
147      */
148     public PoissonDistribution(double p, int maxIterations) {
149         this(p, DEFAULT_EPSILON, maxIterations);
150     }
151 
152     /**
153      * Get the mean for the distribution.
154      *
155      * @return the mean for the distribution.
156      */
157     public double getMean() {
158         return mean;
159     }
160 
161     /** {@inheritDoc} */
162     public double probability(int x) {
163         final double logProbability = logProbability(x);
164         return logProbability == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logProbability);
165     }
166 
167     /** {@inheritDoc} */
168     @Override
169     public double logProbability(int x) {
170         double ret;
171         if (x < 0 || x == Integer.MAX_VALUE) {
172             ret = Double.NEGATIVE_INFINITY;
173         } else if (x == 0) {
174             ret = -mean;
175         } else {
176             ret = -SaddlePointExpansion.getStirlingError(x) -
177                   SaddlePointExpansion.getDeviancePart(x, mean) -
178                   0.5 * FastMath.log(MathUtils.TWO_PI) - 0.5 * FastMath.log(x);
179         }
180         return ret;
181     }
182 
183     /** {@inheritDoc} */
184     public double cumulativeProbability(int x) {
185         if (x < 0) {
186             return 0;
187         }
188         if (x == Integer.MAX_VALUE) {
189             return 1;
190         }
191         return Gamma.regularizedGammaQ((double) x + 1, mean, epsilon,
192                                        maxIterations);
193     }
194 
195     /**
196      * Calculates the Poisson distribution function using a normal
197      * approximation. The {@code N(mean, sqrt(mean))} distribution is used
198      * to approximate the Poisson distribution. The computation uses
199      * "half-correction" (evaluating the normal distribution function at
200      * {@code x + 0.5}).
201      *
202      * @param x Upper bound, inclusive.
203      * @return the distribution function value calculated using a normal
204      * approximation.
205      */
206     public double normalApproximateProbability(int x)  {
207         // calculate the probability using half-correction
208         return normal.cumulativeProbability(x + 0.5);
209     }
210 
211     /**
212      * {@inheritDoc}
213      *
214      * For mean parameter {@code p}, the mean is {@code p}.
215      */
216     public double getNumericalMean() {
217         return getMean();
218     }
219 
220     /**
221      * {@inheritDoc}
222      *
223      * For mean parameter {@code p}, the variance is {@code p}.
224      */
225     public double getNumericalVariance() {
226         return getMean();
227     }
228 
229     /**
230      * {@inheritDoc}
231      *
232      * The lower bound of the support is always 0 no matter the mean parameter.
233      *
234      * @return lower bound of the support (always 0)
235      */
236     public int getSupportLowerBound() {
237         return 0;
238     }
239 
240     /**
241      * {@inheritDoc}
242      *
243      * The upper bound of the support is positive infinity,
244      * regardless of the parameter values. There is no integer infinity,
245      * so this method returns {@code Integer.MAX_VALUE}.
246      *
247      * @return upper bound of the support (always {@code Integer.MAX_VALUE} for
248      * positive infinity)
249      */
250     public int getSupportUpperBound() {
251         return Integer.MAX_VALUE;
252     }
253 
254     /**
255      * {@inheritDoc}
256      *
257      * The support of this distribution is connected.
258      *
259      * @return {@code true}
260      */
261     public boolean isSupportConnected() {
262         return true;
263     }
264 
265     /**
266      * {@inheritDoc}
267      * <p>
268      * <strong>Algorithm Description</strong>:
269      * <ul>
270      *  <li>For small means, uses simulation of a Poisson process
271      *   using Uniform deviates, as described
272      *   <a href="http://irmi.epfl.ch/cmos/Pmmi/interactive/rng7.htm"> here</a>.
273      *   The Poisson process (and hence value returned) is bounded by 1000 * mean.
274      *  </li>
275      *  <li>For large means, uses the rejection algorithm described in
276      *   <quote>
277      *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i>
278      *    <strong>Computing</strong> vol. 26 pp. 197-207.
279      *   </quote>
280      *  </li>
281      * </ul>
282      * </p>
283      *
284      * @return a random value.
285      * @since 2.2
286      */
287     @Override
288     public int sample() {
289         return (int) FastMath.min(nextPoisson(mean), Integer.MAX_VALUE);
290     }
291 
292     /**
293      * @param meanPoisson Mean of the Poisson distribution.
294      * @return the next sample.
295      */
296     private long nextPoisson(double meanPoisson) {
297         final double pivot = 40.0d;
298         if (meanPoisson < pivot) {
299             double p = FastMath.exp(-meanPoisson);
300             long n = 0;
301             double r = 1.0d;
302             double rnd = 1.0d;
303 
304             while (n < 1000 * meanPoisson) {
305                 rnd = random.nextDouble();
306                 r *= rnd;
307                 if (r >= p) {
308                     n++;
309                 } else {
310                     return n;
311                 }
312             }
313             return n;
314         } else {
315             final double lambda = FastMath.floor(meanPoisson);
316             final double lambdaFractional = meanPoisson - lambda;
317             final double logLambda = FastMath.log(lambda);
318             final double logLambdaFactorial = CombinatoricsUtils.factorialLog((int) lambda);
319             final long y2 = lambdaFractional < Double.MIN_VALUE ? 0 : nextPoisson(lambdaFractional);
320             final double delta = FastMath.sqrt(lambda * FastMath.log(32 * lambda / FastMath.PI + 1));
321             final double halfDelta = delta / 2;
322             final double twolpd = 2 * lambda + delta;
323             final double a1 = FastMath.sqrt(FastMath.PI * twolpd) * FastMath.exp(1 / (8 * lambda));
324             final double a2 = (twolpd / delta) * FastMath.exp(-delta * (1 + delta) / twolpd);
325             final double aSum = a1 + a2 + 1;
326             final double p1 = a1 / aSum;
327             final double p2 = a2 / aSum;
328             final double c1 = 1 / (8 * lambda);
329 
330             double x = 0;
331             double y = 0;
332             double v = 0;
333             int a = 0;
334             double t = 0;
335             double qr = 0;
336             double qa = 0;
337             for (;;) {
338                 final double u = random.nextDouble();
339                 if (u <= p1) {
340                     final double n = random.nextGaussian();
341                     x = n * FastMath.sqrt(lambda + halfDelta) - 0.5d;
342                     if (x > delta || x < -lambda) {
343                         continue;
344                     }
345                     y = x < 0 ? FastMath.floor(x) : FastMath.ceil(x);
346                     final double e = exponential.sample();
347                     v = -e - (n * n / 2) + c1;
348                 } else {
349                     if (u > p1 + p2) {
350                         y = lambda;
351                         break;
352                     } else {
353                         x = delta + (twolpd / delta) * exponential.sample();
354                         y = FastMath.ceil(x);
355                         v = -exponential.sample() - delta * (x + 1) / twolpd;
356                     }
357                 }
358                 a = x < 0 ? 1 : 0;
359                 t = y * (y + 1) / (2 * lambda);
360                 if (v < -t && a == 0) {
361                     y = lambda + y;
362                     break;
363                 }
364                 qr = t * ((2 * y + 1) / (6 * lambda) - 1);
365                 qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
366                 if (v < qa) {
367                     y = lambda + y;
368                     break;
369                 }
370                 if (v > qr) {
371                     continue;
372                 }
373                 if (v < y * logLambda - CombinatoricsUtils.factorialLog((int) (y + lambda)) + logLambdaFactorial) {
374                     y = lambda + y;
375                     break;
376                 }
377             }
378             return y2 + (long) y;
379         }
380     }
381 }