View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21  
22  /**
23   * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
24   *
25   * <ul>
26   *  <li>
27   *   For large means, we use the rejection algorithm described in
28   *   <blockquote>
29   *    Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random Variables</i><br>
30   *    <strong>Computing</strong> vol. 26 pp. 197-207.
31   *   </blockquote>
32   *  </li>
33   * </ul>
34   *
35   * <p>This sampler is suitable for {@code mean >= 40}.</p>
36   *
37   * <p>Sampling uses:</p>
38   *
39   * <ul>
40   *   <li>{@link UniformRandomProvider#nextLong()}
41   *   <li>{@link UniformRandomProvider#nextDouble()}
42   * </ul>
43   *
44   * @since 1.1
45   */
46  public class LargeMeanPoissonSampler
47      implements SharedStateDiscreteSampler {
48      /** Upper bound to avoid truncation. */
49      private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50      /** Class to compute {@code log(n!)}. This has no cached values. */
51      private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52      /** Used when there is no requirement for a small mean Poisson sampler. */
53      private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54          new SharedStateDiscreteSampler() {
55              @Override
56              public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57                  // No requirement for RNG
58                  return this;
59              }
60  
61              @Override
62              public int sample() {
63                  // No Poisson sample
64                  return 0;
65              }
66          };
67  
68      static {
69          // Create without a cache.
70          NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71      }
72  
73      /** Underlying source of randomness. */
74      private final UniformRandomProvider rng;
75      /** Exponential. */
76      private final SharedStateContinuousSampler exponential;
77      /** Gaussian. */
78      private final SharedStateContinuousSampler gaussian;
79      /** Local class to compute {@code log(n!)}. This may have cached values. */
80      private final InternalUtils.FactorialLog factorialLog;
81  
82      // Working values
83  
84      /** Algorithm constant: {@code Math.floor(mean)}. */
85      private final double lambda;
86      /** Algorithm constant: {@code Math.log(lambda)}. */
87      private final double logLambda;
88      /** Algorithm constant: {@code factorialLog((int) lambda)}. */
89      private final double logLambdaFactorial;
90      /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
91      private final double delta;
92      /** Algorithm constant: {@code delta / 2}. */
93      private final double halfDelta;
94      /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
95      private final double sqrtLambdaPlusHalfDelta;
96      /** Algorithm constant: {@code 2 * lambda + delta}. */
97      private final double twolpd;
98      /**
99       * Algorithm constant: {@code a1 / aSum}.
100      * <ul>
101      *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
102      *  <li>{@code aSum = a1 + a2 + 1}</li>
103      * </ul>
104      */
105     private final double p1;
106     /**
107      * Algorithm constant: {@code a2 / aSum}.
108      * <ul>
109      *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
110      *  <li>{@code aSum = a1 + a2 + 1}</li>
111      * </ul>
112      */
113     private final double p2;
114     /** Algorithm constant: {@code 1 / (8 * lambda)}. */
115     private final double c1;
116 
117     /** The internal Poisson sampler for the lambda fraction. */
118     private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119 
120 
121     /**
122      * Create an instance.
123      *
124      * @param rng Generator of uniformly distributed random numbers.
125      * @param mean Mean.
126      * @throws IllegalArgumentException if {@code mean < 1} or
127      * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
128      */
129     public LargeMeanPoissonSampler(UniformRandomProvider rng,
130                                    double mean) {
131         // Validation before java.lang.Object constructor exits prevents partially initialized object
132         this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng);
133     }
134 
135     /**
136      * Instantiates a sampler using a precomputed state.
137      *
138      * @param rng              Generator of uniformly distributed random numbers.
139      * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
140      * @param lambdaFractional The lambda fractional value
141      *                         ({@code mean - (int)Math.floor(mean))}.
142      * @throws IllegalArgumentException
143      *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
144      */
145     LargeMeanPoissonSampler(UniformRandomProvider rng,
146                             LargeMeanPoissonSamplerState state,
147                             double lambdaFractional) {
148         // Validation before java.lang.Object constructor exits prevents partially initialized object
149         this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng);
150     }
151 
152     /**
153      * @param mean Mean.
154      * @param rng Generator of uniformly distributed random numbers.
155      */
156     private LargeMeanPoissonSampler(double mean,
157                                     UniformRandomProvider rng) {
158         this.rng = rng;
159 
160         gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
161         exponential = ZigguratSampler.Exponential.of(rng);
162         // Plain constructor uses the uncached function.
163         factorialLog = NO_CACHE_FACTORIAL_LOG;
164 
165         // Cache values used in the algorithm
166         lambda = Math.floor(mean);
167         logLambda = Math.log(lambda);
168         logLambdaFactorial = getFactorialLog((int) lambda);
169         delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
170         halfDelta = delta / 2;
171         sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
172         twolpd = 2 * lambda + delta;
173         c1 = 1 / (8 * lambda);
174         final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
175         final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
176         final double aSum = a1 + a2 + 1;
177         p1 = a1 / aSum;
178         p2 = a2 / aSum;
179 
180         // The algorithm requires a Poisson sample from the remaining lambda fraction.
181         final double lambdaFractional = mean - lambda;
182         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
183             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
184             KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
185     }
186 
187     /**
188      * Instantiates a sampler using a precomputed state.
189      *
190      * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
191      * @param lambdaFractional The lambda fractional value
192      *                         ({@code mean - (int)Math.floor(mean))}.
193      * @param rng              Generator of uniformly distributed random numbers.
194      */
195     private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state,
196                                     double lambdaFractional,
197                                     UniformRandomProvider rng) {
198         this.rng = rng;
199 
200         gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
201         exponential = ZigguratSampler.Exponential.of(rng);
202         // Plain constructor uses the uncached function.
203         factorialLog = NO_CACHE_FACTORIAL_LOG;
204 
205         // Use the state to initialize the algorithm
206         lambda = state.getLambdaRaw();
207         logLambda = state.getLogLambda();
208         logLambdaFactorial = state.getLogLambdaFactorial();
209         delta = state.getDelta();
210         halfDelta = state.getHalfDelta();
211         sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
212         twolpd = state.getTwolpd();
213         p1 = state.getP1();
214         p2 = state.getP2();
215         c1 = state.getC1();
216 
217         // The algorithm requires a Poisson sample from the remaining lambda fraction.
218         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
219             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
220             KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
221     }
222 
223     /**
224      * @param rng Generator of uniformly distributed random numbers.
225      * @param source Source to copy.
226      */
227     private LargeMeanPoissonSampler(UniformRandomProvider rng,
228                                     LargeMeanPoissonSampler source) {
229         this.rng = rng;
230 
231         gaussian = source.gaussian.withUniformRandomProvider(rng);
232         exponential = source.exponential.withUniformRandomProvider(rng);
233         // Reuse the cache
234         factorialLog = source.factorialLog;
235 
236         lambda = source.lambda;
237         logLambda = source.logLambda;
238         logLambdaFactorial = source.logLambdaFactorial;
239         delta = source.delta;
240         halfDelta = source.halfDelta;
241         sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
242         twolpd = source.twolpd;
243         p1 = source.p1;
244         p2 = source.p2;
245         c1 = source.c1;
246 
247         // Share the state of the small sampler
248         smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
249     }
250 
251     /** {@inheritDoc} */
252     @Override
253     public int sample() {
254         // This will never be null. It may be a no-op delegate that returns zero.
255         final int y2 = smallMeanPoissonSampler.sample();
256 
257         double x;
258         double y;
259         double v;
260         int a;
261         double t;
262         double qr;
263         double qa;
264         while (true) {
265             // Step 1:
266             final double u = rng.nextDouble();
267             if (u <= p1) {
268                 // Step 2:
269                 final double n = gaussian.sample();
270                 x = n * sqrtLambdaPlusHalfDelta - 0.5d;
271                 if (x > delta || x < -lambda) {
272                     continue;
273                 }
274                 y = x < 0 ? Math.floor(x) : Math.ceil(x);
275                 final double e = exponential.sample();
276                 v = -e - 0.5 * n * n + c1;
277             } else {
278                 // Step 3:
279                 if (u > p1 + p2) {
280                     y = lambda;
281                     break;
282                 }
283                 x = delta + (twolpd / delta) * exponential.sample();
284                 y = Math.ceil(x);
285                 v = -exponential.sample() - delta * (x + 1) / twolpd;
286             }
287             // The Squeeze Principle
288             // Step 4.1:
289             a = x < 0 ? 1 : 0;
290             t = y * (y + 1) / (2 * lambda);
291             // Step 4.2
292             if (v < -t && a == 0) {
293                 y = lambda + y;
294                 break;
295             }
296             // Step 4.3:
297             qr = t * ((2 * y + 1) / (6 * lambda) - 1);
298             qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
299             // Step 4.4:
300             if (v < qa) {
301                 y = lambda + y;
302                 break;
303             }
304             // Step 4.5:
305             if (v > qr) {
306                 continue;
307             }
308             // Step 4.6:
309             if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
310                 y = lambda + y;
311                 break;
312             }
313         }
314 
315         return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
316     }
317 
318     /**
319      * Compute the natural logarithm of the factorial of {@code n}.
320      *
321      * @param n Argument.
322      * @return {@code log(n!)}
323      * @throws IllegalArgumentException if {@code n < 0}.
324      */
325     private double getFactorialLog(int n) {
326         return factorialLog.value(n);
327     }
328 
329     /** {@inheritDoc} */
330     @Override
331     public String toString() {
332         return "Large Mean Poisson deviate [" + rng.toString() + "]";
333     }
334 
335     /**
336      * {@inheritDoc}
337      *
338      * @since 1.3
339      */
340     @Override
341     public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
342         return new LargeMeanPoissonSampler(rng, this);
343     }
344 
345     /**
346      * Creates a new Poisson distribution sampler.
347      *
348      * @param rng Generator of uniformly distributed random numbers.
349      * @param mean Mean.
350      * @return the sampler
351      * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
352      * {@link Integer#MAX_VALUE}.
353      * @since 1.3
354      */
355     public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
356                                                 double mean) {
357         return new LargeMeanPoissonSampler(rng, mean);
358     }
359 
360     /**
361      * Gets the initialisation state of the sampler.
362      *
363      * <p>The state is computed using an integer {@code lambda} value of
364      * {@code lambda = (int)Math.floor(mean)}.
365      *
366      * <p>The state will be suitable for reconstructing a new sampler with a mean
367      * in the range {@code lambda <= mean < lambda+1} using
368      * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
369      *
370      * @return the state
371      */
372     LargeMeanPoissonSamplerState getState() {
373         return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
374                 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
375     }
376 
377     /**
378      * Encapsulate the state of the sampler. The state is valid for construction of
379      * a sampler in the range {@code lambda <= mean < lambda+1}.
380      *
381      * <p>This class is immutable.
382      *
383      * @see #getLambda()
384      */
385     static final class LargeMeanPoissonSamplerState {
386         /** Algorithm constant {@code lambda}. */
387         private final double lambda;
388         /** Algorithm constant {@code logLambda}. */
389         private final double logLambda;
390         /** Algorithm constant {@code logLambdaFactorial}. */
391         private final double logLambdaFactorial;
392         /** Algorithm constant {@code delta}. */
393         private final double delta;
394         /** Algorithm constant {@code halfDelta}. */
395         private final double halfDelta;
396         /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
397         private final double sqrtLambdaPlusHalfDelta;
398         /** Algorithm constant {@code twolpd}. */
399         private final double twolpd;
400         /** Algorithm constant {@code p1}. */
401         private final double p1;
402         /** Algorithm constant {@code p2}. */
403         private final double p2;
404         /** Algorithm constant {@code c1}. */
405         private final double c1;
406 
407         /**
408          * Creates the state.
409          *
410          * <p>The state is valid for construction of a sampler in the range
411          * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
412          *
413          * @param lambda the lambda
414          * @param logLambda the log lambda
415          * @param logLambdaFactorial the log lambda factorial
416          * @param delta the delta
417          * @param halfDelta the half delta
418          * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
419          * @param twolpd the two lambda plus delta
420          * @param p1 the p1 constant
421          * @param p2 the p2 constant
422          * @param c1 the c1 constant
423          */
424         LargeMeanPoissonSamplerState(double lambda, double logLambda,
425                 double logLambdaFactorial, double delta, double halfDelta,
426                 double sqrtLambdaPlusHalfDelta, double twolpd,
427                 double p1, double p2, double c1) {
428             this.lambda = lambda;
429             this.logLambda = logLambda;
430             this.logLambdaFactorial = logLambdaFactorial;
431             this.delta = delta;
432             this.halfDelta = halfDelta;
433             this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
434             this.twolpd = twolpd;
435             this.p1 = p1;
436             this.p2 = p2;
437             this.c1 = c1;
438         }
439 
440         /**
441          * Get the lambda value for the state.
442          *
443          * <p>Equal to {@code floor(mean)} for a Poisson sampler.
444          * @return the lambda value
445          */
446         int getLambda() {
447             return (int) getLambdaRaw();
448         }
449 
450         /**
451          * @return algorithm constant {@code lambda}
452          */
453         double getLambdaRaw() {
454             return lambda;
455         }
456 
457         /**
458          * @return algorithm constant {@code logLambda}
459          */
460         double getLogLambda() {
461             return logLambda;
462         }
463 
464         /**
465          * @return algorithm constant {@code logLambdaFactorial}
466          */
467         double getLogLambdaFactorial() {
468             return logLambdaFactorial;
469         }
470 
471         /**
472          * @return algorithm constant {@code delta}
473          */
474         double getDelta() {
475             return delta;
476         }
477 
478         /**
479          * @return algorithm constant {@code halfDelta}
480          */
481         double getHalfDelta() {
482             return halfDelta;
483         }
484 
485         /**
486          * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
487          */
488         double getSqrtLambdaPlusHalfDelta() {
489             return sqrtLambdaPlusHalfDelta;
490         }
491 
492         /**
493          * @return algorithm constant {@code twolpd}
494          */
495         double getTwolpd() {
496             return twolpd;
497         }
498 
499         /**
500          * @return algorithm constant {@code p1}
501          */
502         double getP1() {
503             return p1;
504         }
505 
506         /**
507          * @return algorithm constant {@code p2}
508          */
509         double getP2() {
510             return p2;
511         }
512 
513         /**
514          * @return algorithm constant {@code c1}
515          */
516         double getC1() {
517             return c1;
518         }
519     }
520 }