1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21
22 /**
23 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
24 *
25 * <ul>
26 * <li>
27 * For large means, we use the rejection algorithm described in
28 * <blockquote>
 * Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random Variables</i><br>
30 * <strong>Computing</strong> vol. 26 pp. 197-207.
31 * </blockquote>
32 * </li>
33 * </ul>
34 *
35 * <p>This sampler is suitable for {@code mean >= 40}.</p>
36 *
37 * <p>Sampling uses:</p>
38 *
39 * <ul>
40 * <li>{@link UniformRandomProvider#nextLong()}
41 * <li>{@link UniformRandomProvider#nextDouble()}
42 * </ul>
43 *
44 * @since 1.1
45 */
46 public class LargeMeanPoissonSampler
47 implements SharedStateDiscreteSampler {
48 /** Upper bound to avoid truncation. */
49 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50 /** Class to compute {@code log(n!)}. This has no cached values. */
51 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52 /** Used when there is no requirement for a small mean Poisson sampler. */
53 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54 new SharedStateDiscreteSampler() {
55 @Override
56 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57 // No requirement for RNG
58 return this;
59 }
60
61 @Override
62 public int sample() {
63 // No Poisson sample
64 return 0;
65 }
66 };
67
68 static {
69 // Create without a cache.
70 NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71 }
72
73 /** Underlying source of randomness. */
74 private final UniformRandomProvider rng;
75 /** Exponential. */
76 private final SharedStateContinuousSampler exponential;
77 /** Gaussian. */
78 private final SharedStateContinuousSampler gaussian;
79 /** Local class to compute {@code log(n!)}. This may have cached values. */
80 private final InternalUtils.FactorialLog factorialLog;
81
82 // Working values
83
84 /** Algorithm constant: {@code Math.floor(mean)}. */
85 private final double lambda;
86 /** Algorithm constant: {@code Math.log(lambda)}. */
87 private final double logLambda;
88 /** Algorithm constant: {@code factorialLog((int) lambda)}. */
89 private final double logLambdaFactorial;
90 /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
91 private final double delta;
92 /** Algorithm constant: {@code delta / 2}. */
93 private final double halfDelta;
94 /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
95 private final double sqrtLambdaPlusHalfDelta;
96 /** Algorithm constant: {@code 2 * lambda + delta}. */
97 private final double twolpd;
98 /**
99 * Algorithm constant: {@code a1 / aSum}.
100 * <ul>
101 * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
102 * <li>{@code aSum = a1 + a2 + 1}</li>
103 * </ul>
104 */
105 private final double p1;
106 /**
107 * Algorithm constant: {@code a2 / aSum}.
108 * <ul>
109 * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
110 * <li>{@code aSum = a1 + a2 + 1}</li>
111 * </ul>
112 */
113 private final double p2;
114 /** Algorithm constant: {@code 1 / (8 * lambda)}. */
115 private final double c1;
116
117 /** The internal Poisson sampler for the lambda fraction. */
118 private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119
120
121 /**
122 * Create an instance.
123 *
124 * @param rng Generator of uniformly distributed random numbers.
125 * @param mean Mean.
126 * @throws IllegalArgumentException if {@code mean < 1} or
127 * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
128 */
129 public LargeMeanPoissonSampler(UniformRandomProvider rng,
130 double mean) {
131 // Validation before java.lang.Object constructor exits prevents partially initialized object
132 this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng);
133 }
134
135 /**
136 * Instantiates a sampler using a precomputed state.
137 *
138 * @param rng Generator of uniformly distributed random numbers.
139 * @param state The state for {@code lambda = (int)Math.floor(mean)}.
140 * @param lambdaFractional The lambda fractional value
141 * ({@code mean - (int)Math.floor(mean))}.
142 * @throws IllegalArgumentException
143 * if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
144 */
145 LargeMeanPoissonSampler(UniformRandomProvider rng,
146 LargeMeanPoissonSamplerState state,
147 double lambdaFractional) {
148 // Validation before java.lang.Object constructor exits prevents partially initialized object
149 this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng);
150 }
151
152 /**
153 * @param mean Mean.
154 * @param rng Generator of uniformly distributed random numbers.
155 */
156 private LargeMeanPoissonSampler(double mean,
157 UniformRandomProvider rng) {
158 this.rng = rng;
159
160 gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
161 exponential = ZigguratSampler.Exponential.of(rng);
162 // Plain constructor uses the uncached function.
163 factorialLog = NO_CACHE_FACTORIAL_LOG;
164
165 // Cache values used in the algorithm
166 lambda = Math.floor(mean);
167 logLambda = Math.log(lambda);
168 logLambdaFactorial = getFactorialLog((int) lambda);
169 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
170 halfDelta = delta / 2;
171 sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
172 twolpd = 2 * lambda + delta;
173 c1 = 1 / (8 * lambda);
174 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
175 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
176 final double aSum = a1 + a2 + 1;
177 p1 = a1 / aSum;
178 p2 = a2 / aSum;
179
180 // The algorithm requires a Poisson sample from the remaining lambda fraction.
181 final double lambdaFractional = mean - lambda;
182 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
183 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
184 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
185 }
186
187 /**
188 * Instantiates a sampler using a precomputed state.
189 *
190 * @param state The state for {@code lambda = (int)Math.floor(mean)}.
191 * @param lambdaFractional The lambda fractional value
192 * ({@code mean - (int)Math.floor(mean))}.
193 * @param rng Generator of uniformly distributed random numbers.
194 */
195 private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state,
196 double lambdaFractional,
197 UniformRandomProvider rng) {
198 this.rng = rng;
199
200 gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
201 exponential = ZigguratSampler.Exponential.of(rng);
202 // Plain constructor uses the uncached function.
203 factorialLog = NO_CACHE_FACTORIAL_LOG;
204
205 // Use the state to initialize the algorithm
206 lambda = state.getLambdaRaw();
207 logLambda = state.getLogLambda();
208 logLambdaFactorial = state.getLogLambdaFactorial();
209 delta = state.getDelta();
210 halfDelta = state.getHalfDelta();
211 sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
212 twolpd = state.getTwolpd();
213 p1 = state.getP1();
214 p2 = state.getP2();
215 c1 = state.getC1();
216
217 // The algorithm requires a Poisson sample from the remaining lambda fraction.
218 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
219 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
220 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
221 }
222
223 /**
224 * @param rng Generator of uniformly distributed random numbers.
225 * @param source Source to copy.
226 */
227 private LargeMeanPoissonSampler(UniformRandomProvider rng,
228 LargeMeanPoissonSampler source) {
229 this.rng = rng;
230
231 gaussian = source.gaussian.withUniformRandomProvider(rng);
232 exponential = source.exponential.withUniformRandomProvider(rng);
233 // Reuse the cache
234 factorialLog = source.factorialLog;
235
236 lambda = source.lambda;
237 logLambda = source.logLambda;
238 logLambdaFactorial = source.logLambdaFactorial;
239 delta = source.delta;
240 halfDelta = source.halfDelta;
241 sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
242 twolpd = source.twolpd;
243 p1 = source.p1;
244 p2 = source.p2;
245 c1 = source.c1;
246
247 // Share the state of the small sampler
248 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
249 }
250
251 /** {@inheritDoc} */
252 @Override
253 public int sample() {
254 // This will never be null. It may be a no-op delegate that returns zero.
255 final int y2 = smallMeanPoissonSampler.sample();
256
257 double x;
258 double y;
259 double v;
260 int a;
261 double t;
262 double qr;
263 double qa;
264 while (true) {
265 // Step 1:
266 final double u = rng.nextDouble();
267 if (u <= p1) {
268 // Step 2:
269 final double n = gaussian.sample();
270 x = n * sqrtLambdaPlusHalfDelta - 0.5d;
271 if (x > delta || x < -lambda) {
272 continue;
273 }
274 y = x < 0 ? Math.floor(x) : Math.ceil(x);
275 final double e = exponential.sample();
276 v = -e - 0.5 * n * n + c1;
277 } else {
278 // Step 3:
279 if (u > p1 + p2) {
280 y = lambda;
281 break;
282 }
283 x = delta + (twolpd / delta) * exponential.sample();
284 y = Math.ceil(x);
285 v = -exponential.sample() - delta * (x + 1) / twolpd;
286 }
287 // The Squeeze Principle
288 // Step 4.1:
289 a = x < 0 ? 1 : 0;
290 t = y * (y + 1) / (2 * lambda);
291 // Step 4.2
292 if (v < -t && a == 0) {
293 y = lambda + y;
294 break;
295 }
296 // Step 4.3:
297 qr = t * ((2 * y + 1) / (6 * lambda) - 1);
298 qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
299 // Step 4.4:
300 if (v < qa) {
301 y = lambda + y;
302 break;
303 }
304 // Step 4.5:
305 if (v > qr) {
306 continue;
307 }
308 // Step 4.6:
309 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
310 y = lambda + y;
311 break;
312 }
313 }
314
315 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
316 }
317
318 /**
319 * Compute the natural logarithm of the factorial of {@code n}.
320 *
321 * @param n Argument.
322 * @return {@code log(n!)}
323 * @throws IllegalArgumentException if {@code n < 0}.
324 */
325 private double getFactorialLog(int n) {
326 return factorialLog.value(n);
327 }
328
329 /** {@inheritDoc} */
330 @Override
331 public String toString() {
332 return "Large Mean Poisson deviate [" + rng.toString() + "]";
333 }
334
335 /**
336 * {@inheritDoc}
337 *
338 * @since 1.3
339 */
340 @Override
341 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
342 return new LargeMeanPoissonSampler(rng, this);
343 }
344
345 /**
346 * Creates a new Poisson distribution sampler.
347 *
348 * @param rng Generator of uniformly distributed random numbers.
349 * @param mean Mean.
350 * @return the sampler
351 * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
352 * {@link Integer#MAX_VALUE}.
353 * @since 1.3
354 */
355 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
356 double mean) {
357 return new LargeMeanPoissonSampler(rng, mean);
358 }
359
360 /**
361 * Gets the initialisation state of the sampler.
362 *
363 * <p>The state is computed using an integer {@code lambda} value of
364 * {@code lambda = (int)Math.floor(mean)}.
365 *
366 * <p>The state will be suitable for reconstructing a new sampler with a mean
367 * in the range {@code lambda <= mean < lambda+1} using
368 * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
369 *
370 * @return the state
371 */
372 LargeMeanPoissonSamplerState getState() {
373 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
374 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
375 }
376
377 /**
378 * Encapsulate the state of the sampler. The state is valid for construction of
379 * a sampler in the range {@code lambda <= mean < lambda+1}.
380 *
381 * <p>This class is immutable.
382 *
383 * @see #getLambda()
384 */
385 static final class LargeMeanPoissonSamplerState {
386 /** Algorithm constant {@code lambda}. */
387 private final double lambda;
388 /** Algorithm constant {@code logLambda}. */
389 private final double logLambda;
390 /** Algorithm constant {@code logLambdaFactorial}. */
391 private final double logLambdaFactorial;
392 /** Algorithm constant {@code delta}. */
393 private final double delta;
394 /** Algorithm constant {@code halfDelta}. */
395 private final double halfDelta;
396 /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
397 private final double sqrtLambdaPlusHalfDelta;
398 /** Algorithm constant {@code twolpd}. */
399 private final double twolpd;
400 /** Algorithm constant {@code p1}. */
401 private final double p1;
402 /** Algorithm constant {@code p2}. */
403 private final double p2;
404 /** Algorithm constant {@code c1}. */
405 private final double c1;
406
407 /**
408 * Creates the state.
409 *
410 * <p>The state is valid for construction of a sampler in the range
411 * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
412 *
413 * @param lambda the lambda
414 * @param logLambda the log lambda
415 * @param logLambdaFactorial the log lambda factorial
416 * @param delta the delta
417 * @param halfDelta the half delta
418 * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
419 * @param twolpd the two lambda plus delta
420 * @param p1 the p1 constant
421 * @param p2 the p2 constant
422 * @param c1 the c1 constant
423 */
424 LargeMeanPoissonSamplerState(double lambda, double logLambda,
425 double logLambdaFactorial, double delta, double halfDelta,
426 double sqrtLambdaPlusHalfDelta, double twolpd,
427 double p1, double p2, double c1) {
428 this.lambda = lambda;
429 this.logLambda = logLambda;
430 this.logLambdaFactorial = logLambdaFactorial;
431 this.delta = delta;
432 this.halfDelta = halfDelta;
433 this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
434 this.twolpd = twolpd;
435 this.p1 = p1;
436 this.p2 = p2;
437 this.c1 = c1;
438 }
439
440 /**
441 * Get the lambda value for the state.
442 *
443 * <p>Equal to {@code floor(mean)} for a Poisson sampler.
444 * @return the lambda value
445 */
446 int getLambda() {
447 return (int) getLambdaRaw();
448 }
449
450 /**
451 * @return algorithm constant {@code lambda}
452 */
453 double getLambdaRaw() {
454 return lambda;
455 }
456
457 /**
458 * @return algorithm constant {@code logLambda}
459 */
460 double getLogLambda() {
461 return logLambda;
462 }
463
464 /**
465 * @return algorithm constant {@code logLambdaFactorial}
466 */
467 double getLogLambdaFactorial() {
468 return logLambdaFactorial;
469 }
470
471 /**
472 * @return algorithm constant {@code delta}
473 */
474 double getDelta() {
475 return delta;
476 }
477
478 /**
479 * @return algorithm constant {@code halfDelta}
480 */
481 double getHalfDelta() {
482 return halfDelta;
483 }
484
485 /**
486 * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
487 */
488 double getSqrtLambdaPlusHalfDelta() {
489 return sqrtLambdaPlusHalfDelta;
490 }
491
492 /**
493 * @return algorithm constant {@code twolpd}
494 */
495 double getTwolpd() {
496 return twolpd;
497 }
498
499 /**
500 * @return algorithm constant {@code p1}
501 */
502 double getP1() {
503 return p1;
504 }
505
506 /**
507 * @return algorithm constant {@code p2}
508 */
509 double getP2() {
510 return p2;
511 }
512
513 /**
514 * @return algorithm constant {@code c1}
515 */
516 double getC1() {
517 return c1;
518 }
519 }
520 }