1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46 public class LargeMeanPoissonSampler
47 implements SharedStateDiscreteSampler {
48
49 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50
51 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52
53 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54 new SharedStateDiscreteSampler() {
55 @Override
56 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57
58 return this;
59 }
60
61 @Override
62 public int sample() {
63
64 return 0;
65 }
66 };
67
68 static {
69
70 NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71 }
72
73
74 private final UniformRandomProvider rng;
75
76 private final SharedStateContinuousSampler exponential;
77
78 private final SharedStateContinuousSampler gaussian;
79
80 private final InternalUtils.FactorialLog factorialLog;
81
82
83
84
85 private final double lambda;
86
87 private final double logLambda;
88
89 private final double logLambdaFactorial;
90
91 private final double delta;
92
93 private final double halfDelta;
94
95 private final double sqrtLambdaPlusHalfDelta;
96
97 private final double twolpd;
98
99
100
101
102
103
104
105 private final double p1;
106
107
108
109
110
111
112
113 private final double p2;
114
115 private final double c1;
116
117
118 private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119
120
121
122
123
124
125
126
127
128
129 public LargeMeanPoissonSampler(UniformRandomProvider rng,
130 double mean) {
131
132 this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng);
133 }
134
135
136
137
138
139
140
141
142
143
144
145 LargeMeanPoissonSampler(UniformRandomProvider rng,
146 LargeMeanPoissonSamplerState state,
147 double lambdaFractional) {
148
149 this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng);
150 }
151
152
153
154
155
156 private LargeMeanPoissonSampler(double mean,
157 UniformRandomProvider rng) {
158 this.rng = rng;
159
160 gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
161 exponential = ZigguratSampler.Exponential.of(rng);
162
163 factorialLog = NO_CACHE_FACTORIAL_LOG;
164
165
166 lambda = Math.floor(mean);
167 logLambda = Math.log(lambda);
168 logLambdaFactorial = getFactorialLog((int) lambda);
169 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
170 halfDelta = delta / 2;
171 sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
172 twolpd = 2 * lambda + delta;
173 c1 = 1 / (8 * lambda);
174 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
175 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
176 final double aSum = a1 + a2 + 1;
177 p1 = a1 / aSum;
178 p2 = a2 / aSum;
179
180
181 final double lambdaFractional = mean - lambda;
182 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
183 NO_SMALL_MEAN_POISSON_SAMPLER :
184 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
185 }
186
187
188
189
190
191
192
193
194
195 private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state,
196 double lambdaFractional,
197 UniformRandomProvider rng) {
198 this.rng = rng;
199
200 gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
201 exponential = ZigguratSampler.Exponential.of(rng);
202
203 factorialLog = NO_CACHE_FACTORIAL_LOG;
204
205
206 lambda = state.getLambdaRaw();
207 logLambda = state.getLogLambda();
208 logLambdaFactorial = state.getLogLambdaFactorial();
209 delta = state.getDelta();
210 halfDelta = state.getHalfDelta();
211 sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
212 twolpd = state.getTwolpd();
213 p1 = state.getP1();
214 p2 = state.getP2();
215 c1 = state.getC1();
216
217
218 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
219 NO_SMALL_MEAN_POISSON_SAMPLER :
220 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
221 }
222
223
224
225
226
227 private LargeMeanPoissonSampler(UniformRandomProvider rng,
228 LargeMeanPoissonSampler source) {
229 this.rng = rng;
230
231 gaussian = source.gaussian.withUniformRandomProvider(rng);
232 exponential = source.exponential.withUniformRandomProvider(rng);
233
234 factorialLog = source.factorialLog;
235
236 lambda = source.lambda;
237 logLambda = source.logLambda;
238 logLambdaFactorial = source.logLambdaFactorial;
239 delta = source.delta;
240 halfDelta = source.halfDelta;
241 sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
242 twolpd = source.twolpd;
243 p1 = source.p1;
244 p2 = source.p2;
245 c1 = source.c1;
246
247
248 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
249 }
250
251
252 @Override
253 public int sample() {
254
255 final int y2 = smallMeanPoissonSampler.sample();
256
257 double x;
258 double y;
259 double v;
260 int a;
261 double t;
262 double qr;
263 double qa;
264 while (true) {
265
266 final double u = rng.nextDouble();
267 if (u <= p1) {
268
269 final double n = gaussian.sample();
270 x = n * sqrtLambdaPlusHalfDelta - 0.5d;
271 if (x > delta || x < -lambda) {
272 continue;
273 }
274 y = x < 0 ? Math.floor(x) : Math.ceil(x);
275 final double e = exponential.sample();
276 v = -e - 0.5 * n * n + c1;
277 } else {
278
279 if (u > p1 + p2) {
280 y = lambda;
281 break;
282 }
283 x = delta + (twolpd / delta) * exponential.sample();
284 y = Math.ceil(x);
285 v = -exponential.sample() - delta * (x + 1) / twolpd;
286 }
287
288
289 a = x < 0 ? 1 : 0;
290 t = y * (y + 1) / (2 * lambda);
291
292 if (v < -t && a == 0) {
293 y = lambda + y;
294 break;
295 }
296
297 qr = t * ((2 * y + 1) / (6 * lambda) - 1);
298 qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
299
300 if (v < qa) {
301 y = lambda + y;
302 break;
303 }
304
305 if (v > qr) {
306 continue;
307 }
308
309 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
310 y = lambda + y;
311 break;
312 }
313 }
314
315 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
316 }
317
318
319
320
321
322
323
324
325 private double getFactorialLog(int n) {
326 return factorialLog.value(n);
327 }
328
329
330 @Override
331 public String toString() {
332 return "Large Mean Poisson deviate [" + rng.toString() + "]";
333 }
334
335
336
337
338
339
340 @Override
341 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
342 return new LargeMeanPoissonSampler(rng, this);
343 }
344
345
346
347
348
349
350
351
352
353
354
355 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
356 double mean) {
357 return new LargeMeanPoissonSampler(rng, mean);
358 }
359
360
361
362
363
364
365
366
367
368
369
370
371
372 LargeMeanPoissonSamplerState getState() {
373 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
374 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
375 }
376
377
378
379
380
381
382
383
384
385 static final class LargeMeanPoissonSamplerState {
386
387 private final double lambda;
388
389 private final double logLambda;
390
391 private final double logLambdaFactorial;
392
393 private final double delta;
394
395 private final double halfDelta;
396
397 private final double sqrtLambdaPlusHalfDelta;
398
399 private final double twolpd;
400
401 private final double p1;
402
403 private final double p2;
404
405 private final double c1;
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424 LargeMeanPoissonSamplerState(double lambda, double logLambda,
425 double logLambdaFactorial, double delta, double halfDelta,
426 double sqrtLambdaPlusHalfDelta, double twolpd,
427 double p1, double p2, double c1) {
428 this.lambda = lambda;
429 this.logLambda = logLambda;
430 this.logLambdaFactorial = logLambdaFactorial;
431 this.delta = delta;
432 this.halfDelta = halfDelta;
433 this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
434 this.twolpd = twolpd;
435 this.p1 = p1;
436 this.p2 = p2;
437 this.c1 = c1;
438 }
439
440
441
442
443
444
445
446 int getLambda() {
447 return (int) getLambdaRaw();
448 }
449
450
451
452
453 double getLambdaRaw() {
454 return lambda;
455 }
456
457
458
459
460 double getLogLambda() {
461 return logLambda;
462 }
463
464
465
466
467 double getLogLambdaFactorial() {
468 return logLambdaFactorial;
469 }
470
471
472
473
474 double getDelta() {
475 return delta;
476 }
477
478
479
480
481 double getHalfDelta() {
482 return halfDelta;
483 }
484
485
486
487
488 double getSqrtLambdaPlusHalfDelta() {
489 return sqrtLambdaPlusHalfDelta;
490 }
491
492
493
494
495 double getTwolpd() {
496 return twolpd;
497 }
498
499
500
501
502 double getP1() {
503 return p1;
504 }
505
506
507
508
509 double getP2() {
510 return p2;
511 }
512
513
514
515
516 double getC1() {
517 return c1;
518 }
519 }
520 }