View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21  
22  /**
23   * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
24   *
25   * <ul>
26   *  <li>
27   *   For large means, we use the rejection algorithm described in
28   *   <blockquote>
29   *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br>
30   *    <strong>Computing</strong> vol. 26 pp. 197-207.
31   *   </blockquote>
32   *  </li>
33   * </ul>
34   *
35   * <p>This sampler is suitable for {@code mean >= 40}.</p>
36   *
37   * <p>Sampling uses:</p>
38   *
39   * <ul>
40   *   <li>{@link UniformRandomProvider#nextLong()}
41   *   <li>{@link UniformRandomProvider#nextDouble()}
42   * </ul>
43   *
44   * @since 1.1
45   */
46  public class LargeMeanPoissonSampler
47      implements SharedStateDiscreteSampler {
48      /** Upper bound to avoid truncation. */
49      private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50      /** Class to compute {@code log(n!)}. This has no cached values. */
51      private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52      /** Used when there is no requirement for a small mean Poisson sampler. */
53      private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54          new SharedStateDiscreteSampler() {
55              @Override
56              public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57                  // No requirement for RNG
58                  return this;
59              }
60  
61              @Override
62              public int sample() {
63                  // No Poisson sample
64                  return 0;
65              }
66          };
67  
68      static {
69          // Create without a cache.
70          NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71      }
72  
73      /** Underlying source of randomness. */
74      private final UniformRandomProvider rng;
75      /** Exponential. */
76      private final SharedStateContinuousSampler exponential;
77      /** Gaussian. */
78      private final SharedStateContinuousSampler gaussian;
79      /** Local class to compute {@code log(n!)}. This may have cached values. */
80      private final InternalUtils.FactorialLog factorialLog;
81  
82      // Working values
83  
84      /** Algorithm constant: {@code Math.floor(mean)}. */
85      private final double lambda;
86      /** Algorithm constant: {@code Math.log(lambda)}. */
87      private final double logLambda;
88      /** Algorithm constant: {@code factorialLog((int) lambda)}. */
89      private final double logLambdaFactorial;
90      /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
91      private final double delta;
92      /** Algorithm constant: {@code delta / 2}. */
93      private final double halfDelta;
94      /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
95      private final double sqrtLambdaPlusHalfDelta;
96      /** Algorithm constant: {@code 2 * lambda + delta}. */
97      private final double twolpd;
98      /**
99       * Algorithm constant: {@code a1 / aSum}.
100      * <ul>
101      *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
102      *  <li>{@code aSum = a1 + a2 + 1}</li>
103      * </ul>
104      */
105     private final double p1;
106     /**
107      * Algorithm constant: {@code a2 / aSum}.
108      * <ul>
109      *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
110      *  <li>{@code aSum = a1 + a2 + 1}</li>
111      * </ul>
112      */
113     private final double p2;
114     /** Algorithm constant: {@code 1 / (8 * lambda)}. */
115     private final double c1;
116 
117     /** The internal Poisson sampler for the lambda fraction. */
118     private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119 
120     /**
121      * @param rng Generator of uniformly distributed random numbers.
122      * @param mean Mean.
123      * @throws IllegalArgumentException if {@code mean < 1} or
124      * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
125      */
126     public LargeMeanPoissonSampler(UniformRandomProvider rng,
127                                    double mean) {
128         if (mean < 1) {
129             throw new IllegalArgumentException("mean is not >= 1: " + mean);
130         }
131         // The algorithm is not valid if Math.floor(mean) is not an integer.
132         if (mean > MAX_MEAN) {
133             throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
134         }
135         this.rng = rng;
136 
137         gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
138         exponential = ZigguratSampler.Exponential.of(rng);
139         // Plain constructor uses the uncached function.
140         factorialLog = NO_CACHE_FACTORIAL_LOG;
141 
142         // Cache values used in the algorithm
143         lambda = Math.floor(mean);
144         logLambda = Math.log(lambda);
145         logLambdaFactorial = getFactorialLog((int) lambda);
146         delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
147         halfDelta = delta / 2;
148         sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
149         twolpd = 2 * lambda + delta;
150         c1 = 1 / (8 * lambda);
151         final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
152         final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
153         final double aSum = a1 + a2 + 1;
154         p1 = a1 / aSum;
155         p2 = a2 / aSum;
156 
157         // The algorithm requires a Poisson sample from the remaining lambda fraction.
158         final double lambdaFractional = mean - lambda;
159         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
160             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
161             KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
162     }
163 
164     /**
165      * Instantiates a sampler using a precomputed state.
166      *
167      * @param rng              Generator of uniformly distributed random numbers.
168      * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
169      * @param lambdaFractional The lambda fractional value
170      *                         ({@code mean - (int)Math.floor(mean))}.
171      * @throws IllegalArgumentException
172      *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
173      */
174     LargeMeanPoissonSampler(UniformRandomProvider rng,
175                             LargeMeanPoissonSamplerState state,
176                             double lambdaFractional) {
177         if (lambdaFractional < 0 || lambdaFractional >= 1) {
178             throw new IllegalArgumentException(
179                     "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
180         }
181         this.rng = rng;
182 
183         gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
184         exponential = ZigguratSampler.Exponential.of(rng);
185         // Plain constructor uses the uncached function.
186         factorialLog = NO_CACHE_FACTORIAL_LOG;
187 
188         // Use the state to initialize the algorithm
189         lambda = state.getLambdaRaw();
190         logLambda = state.getLogLambda();
191         logLambdaFactorial = state.getLogLambdaFactorial();
192         delta = state.getDelta();
193         halfDelta = state.getHalfDelta();
194         sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
195         twolpd = state.getTwolpd();
196         p1 = state.getP1();
197         p2 = state.getP2();
198         c1 = state.getC1();
199 
200         // The algorithm requires a Poisson sample from the remaining lambda fraction.
201         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
202             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
203             KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
204     }
205 
206     /**
207      * @param rng Generator of uniformly distributed random numbers.
208      * @param source Source to copy.
209      */
210     private LargeMeanPoissonSampler(UniformRandomProvider rng,
211                                     LargeMeanPoissonSampler source) {
212         this.rng = rng;
213 
214         gaussian = source.gaussian.withUniformRandomProvider(rng);
215         exponential = source.exponential.withUniformRandomProvider(rng);
216         // Reuse the cache
217         factorialLog = source.factorialLog;
218 
219         lambda = source.lambda;
220         logLambda = source.logLambda;
221         logLambdaFactorial = source.logLambdaFactorial;
222         delta = source.delta;
223         halfDelta = source.halfDelta;
224         sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
225         twolpd = source.twolpd;
226         p1 = source.p1;
227         p2 = source.p2;
228         c1 = source.c1;
229 
230         // Share the state of the small sampler
231         smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
232     }
233 
234     /** {@inheritDoc} */
235     @Override
236     public int sample() {
237         // This will never be null. It may be a no-op delegate that returns zero.
238         final int y2 = smallMeanPoissonSampler.sample();
239 
240         double x;
241         double y;
242         double v;
243         int a;
244         double t;
245         double qr;
246         double qa;
247         while (true) {
248             // Step 1:
249             final double u = rng.nextDouble();
250             if (u <= p1) {
251                 // Step 2:
252                 final double n = gaussian.sample();
253                 x = n * sqrtLambdaPlusHalfDelta - 0.5d;
254                 if (x > delta || x < -lambda) {
255                     continue;
256                 }
257                 y = x < 0 ? Math.floor(x) : Math.ceil(x);
258                 final double e = exponential.sample();
259                 v = -e - 0.5 * n * n + c1;
260             } else {
261                 // Step 3:
262                 if (u > p1 + p2) {
263                     y = lambda;
264                     break;
265                 }
266                 x = delta + (twolpd / delta) * exponential.sample();
267                 y = Math.ceil(x);
268                 v = -exponential.sample() - delta * (x + 1) / twolpd;
269             }
270             // The Squeeze Principle
271             // Step 4.1:
272             a = x < 0 ? 1 : 0;
273             t = y * (y + 1) / (2 * lambda);
274             // Step 4.2
275             if (v < -t && a == 0) {
276                 y = lambda + y;
277                 break;
278             }
279             // Step 4.3:
280             qr = t * ((2 * y + 1) / (6 * lambda) - 1);
281             qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
282             // Step 4.4:
283             if (v < qa) {
284                 y = lambda + y;
285                 break;
286             }
287             // Step 4.5:
288             if (v > qr) {
289                 continue;
290             }
291             // Step 4.6:
292             if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
293                 y = lambda + y;
294                 break;
295             }
296         }
297 
298         return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
299     }
300 
301     /**
302      * Compute the natural logarithm of the factorial of {@code n}.
303      *
304      * @param n Argument.
305      * @return {@code log(n!)}
306      * @throws IllegalArgumentException if {@code n < 0}.
307      */
308     private double getFactorialLog(int n) {
309         return factorialLog.value(n);
310     }
311 
312     /** {@inheritDoc} */
313     @Override
314     public String toString() {
315         return "Large Mean Poisson deviate [" + rng.toString() + "]";
316     }
317 
318     /**
319      * {@inheritDoc}
320      *
321      * @since 1.3
322      */
323     @Override
324     public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
325         return new LargeMeanPoissonSampler(rng, this);
326     }
327 
328     /**
329      * Creates a new Poisson distribution sampler.
330      *
331      * @param rng Generator of uniformly distributed random numbers.
332      * @param mean Mean.
333      * @return the sampler
334      * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
335      * {@link Integer#MAX_VALUE}.
336      * @since 1.3
337      */
338     public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
339                                                 double mean) {
340         return new LargeMeanPoissonSampler(rng, mean);
341     }
342 
343     /**
344      * Gets the initialisation state of the sampler.
345      *
346      * <p>The state is computed using an integer {@code lambda} value of
347      * {@code lambda = (int)Math.floor(mean)}.
348      *
349      * <p>The state will be suitable for reconstructing a new sampler with a mean
350      * in the range {@code lambda <= mean < lambda+1} using
351      * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
352      *
353      * @return the state
354      */
355     LargeMeanPoissonSamplerState getState() {
356         return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
357                 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
358     }
359 
360     /**
361      * Encapsulate the state of the sampler. The state is valid for construction of
362      * a sampler in the range {@code lambda <= mean < lambda+1}.
363      *
364      * <p>This class is immutable.
365      *
366      * @see #getLambda()
367      */
368     static final class LargeMeanPoissonSamplerState {
369         /** Algorithm constant {@code lambda}. */
370         private final double lambda;
371         /** Algorithm constant {@code logLambda}. */
372         private final double logLambda;
373         /** Algorithm constant {@code logLambdaFactorial}. */
374         private final double logLambdaFactorial;
375         /** Algorithm constant {@code delta}. */
376         private final double delta;
377         /** Algorithm constant {@code halfDelta}. */
378         private final double halfDelta;
379         /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
380         private final double sqrtLambdaPlusHalfDelta;
381         /** Algorithm constant {@code twolpd}. */
382         private final double twolpd;
383         /** Algorithm constant {@code p1}. */
384         private final double p1;
385         /** Algorithm constant {@code p2}. */
386         private final double p2;
387         /** Algorithm constant {@code c1}. */
388         private final double c1;
389 
390         /**
391          * Creates the state.
392          *
393          * <p>The state is valid for construction of a sampler in the range
394          * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
395          *
396          * @param lambda the lambda
397          * @param logLambda the log lambda
398          * @param logLambdaFactorial the log lambda factorial
399          * @param delta the delta
400          * @param halfDelta the half delta
401          * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
402          * @param twolpd the two lambda plus delta
403          * @param p1 the p1 constant
404          * @param p2 the p2 constant
405          * @param c1 the c1 constant
406          */
407         LargeMeanPoissonSamplerState(double lambda, double logLambda,
408                 double logLambdaFactorial, double delta, double halfDelta,
409                 double sqrtLambdaPlusHalfDelta, double twolpd,
410                 double p1, double p2, double c1) {
411             this.lambda = lambda;
412             this.logLambda = logLambda;
413             this.logLambdaFactorial = logLambdaFactorial;
414             this.delta = delta;
415             this.halfDelta = halfDelta;
416             this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
417             this.twolpd = twolpd;
418             this.p1 = p1;
419             this.p2 = p2;
420             this.c1 = c1;
421         }
422 
423         /**
424          * Get the lambda value for the state.
425          *
426          * <p>Equal to {@code floor(mean)} for a Poisson sampler.
427          * @return the lambda value
428          */
429         int getLambda() {
430             return (int) getLambdaRaw();
431         }
432 
433         /**
434          * @return algorithm constant {@code lambda}
435          */
436         double getLambdaRaw() {
437             return lambda;
438         }
439 
440         /**
441          * @return algorithm constant {@code logLambda}
442          */
443         double getLogLambda() {
444             return logLambda;
445         }
446 
447         /**
448          * @return algorithm constant {@code logLambdaFactorial}
449          */
450         double getLogLambdaFactorial() {
451             return logLambdaFactorial;
452         }
453 
454         /**
455          * @return algorithm constant {@code delta}
456          */
457         double getDelta() {
458             return delta;
459         }
460 
461         /**
462          * @return algorithm constant {@code halfDelta}
463          */
464         double getHalfDelta() {
465             return halfDelta;
466         }
467 
468         /**
469          * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
470          */
471         double getSqrtLambdaPlusHalfDelta() {
472             return sqrtLambdaPlusHalfDelta;
473         }
474 
475         /**
476          * @return algorithm constant {@code twolpd}
477          */
478         double getTwolpd() {
479             return twolpd;
480         }
481 
482         /**
483          * @return algorithm constant {@code p1}
484          */
485         double getP1() {
486             return p1;
487         }
488 
489         /**
490          * @return algorithm constant {@code p2}
491          */
492         double getP2() {
493             return p2;
494         }
495 
496         /**
497          * @return algorithm constant {@code c1}
498          */
499         double getC1() {
500             return c1;
501         }
502     }
503 }