1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46 public class LargeMeanPoissonSampler
47 implements SharedStateDiscreteSampler {
48
49 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50
51 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52
53 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54 new SharedStateDiscreteSampler() {
55 @Override
56 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57
58 return this;
59 }
60
61 @Override
62 public int sample() {
63
64 return 0;
65 }
66 };
67
68 static {
69
70 NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71 }
72
73
74 private final UniformRandomProvider rng;
75
76 private final SharedStateContinuousSampler exponential;
77
78 private final SharedStateContinuousSampler gaussian;
79
80 private final InternalUtils.FactorialLog factorialLog;
81
82
83
84
85 private final double lambda;
86
87 private final double logLambda;
88
89 private final double logLambdaFactorial;
90
91 private final double delta;
92
93 private final double halfDelta;
94
95 private final double sqrtLambdaPlusHalfDelta;
96
97 private final double twolpd;
98
99
100
101
102
103
104
105 private final double p1;
106
107
108
109
110
111
112
113 private final double p2;
114
115 private final double c1;
116
117
118 private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119
120
121
122
123
124
125
126 public LargeMeanPoissonSampler(UniformRandomProvider rng,
127 double mean) {
128 if (mean < 1) {
129 throw new IllegalArgumentException("mean is not >= 1: " + mean);
130 }
131
132 if (mean > MAX_MEAN) {
133 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
134 }
135 this.rng = rng;
136
137 gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
138 exponential = ZigguratSampler.Exponential.of(rng);
139
140 factorialLog = NO_CACHE_FACTORIAL_LOG;
141
142
143 lambda = Math.floor(mean);
144 logLambda = Math.log(lambda);
145 logLambdaFactorial = getFactorialLog((int) lambda);
146 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
147 halfDelta = delta / 2;
148 sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
149 twolpd = 2 * lambda + delta;
150 c1 = 1 / (8 * lambda);
151 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
152 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
153 final double aSum = a1 + a2 + 1;
154 p1 = a1 / aSum;
155 p2 = a2 / aSum;
156
157
158 final double lambdaFractional = mean - lambda;
159 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
160 NO_SMALL_MEAN_POISSON_SAMPLER :
161 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
162 }
163
164
165
166
167
168
169
170
171
172
173
174 LargeMeanPoissonSampler(UniformRandomProvider rng,
175 LargeMeanPoissonSamplerState state,
176 double lambdaFractional) {
177 if (lambdaFractional < 0 || lambdaFractional >= 1) {
178 throw new IllegalArgumentException(
179 "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
180 }
181 this.rng = rng;
182
183 gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
184 exponential = ZigguratSampler.Exponential.of(rng);
185
186 factorialLog = NO_CACHE_FACTORIAL_LOG;
187
188
189 lambda = state.getLambdaRaw();
190 logLambda = state.getLogLambda();
191 logLambdaFactorial = state.getLogLambdaFactorial();
192 delta = state.getDelta();
193 halfDelta = state.getHalfDelta();
194 sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
195 twolpd = state.getTwolpd();
196 p1 = state.getP1();
197 p2 = state.getP2();
198 c1 = state.getC1();
199
200
201 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
202 NO_SMALL_MEAN_POISSON_SAMPLER :
203 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
204 }
205
206
207
208
209
210 private LargeMeanPoissonSampler(UniformRandomProvider rng,
211 LargeMeanPoissonSampler source) {
212 this.rng = rng;
213
214 gaussian = source.gaussian.withUniformRandomProvider(rng);
215 exponential = source.exponential.withUniformRandomProvider(rng);
216
217 factorialLog = source.factorialLog;
218
219 lambda = source.lambda;
220 logLambda = source.logLambda;
221 logLambdaFactorial = source.logLambdaFactorial;
222 delta = source.delta;
223 halfDelta = source.halfDelta;
224 sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
225 twolpd = source.twolpd;
226 p1 = source.p1;
227 p2 = source.p2;
228 c1 = source.c1;
229
230
231 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
232 }
233
234
235 @Override
236 public int sample() {
237
238 final int y2 = smallMeanPoissonSampler.sample();
239
240 double x;
241 double y;
242 double v;
243 int a;
244 double t;
245 double qr;
246 double qa;
247 while (true) {
248
249 final double u = rng.nextDouble();
250 if (u <= p1) {
251
252 final double n = gaussian.sample();
253 x = n * sqrtLambdaPlusHalfDelta - 0.5d;
254 if (x > delta || x < -lambda) {
255 continue;
256 }
257 y = x < 0 ? Math.floor(x) : Math.ceil(x);
258 final double e = exponential.sample();
259 v = -e - 0.5 * n * n + c1;
260 } else {
261
262 if (u > p1 + p2) {
263 y = lambda;
264 break;
265 }
266 x = delta + (twolpd / delta) * exponential.sample();
267 y = Math.ceil(x);
268 v = -exponential.sample() - delta * (x + 1) / twolpd;
269 }
270
271
272 a = x < 0 ? 1 : 0;
273 t = y * (y + 1) / (2 * lambda);
274
275 if (v < -t && a == 0) {
276 y = lambda + y;
277 break;
278 }
279
280 qr = t * ((2 * y + 1) / (6 * lambda) - 1);
281 qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
282
283 if (v < qa) {
284 y = lambda + y;
285 break;
286 }
287
288 if (v > qr) {
289 continue;
290 }
291
292 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
293 y = lambda + y;
294 break;
295 }
296 }
297
298 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
299 }
300
301
302
303
304
305
306
307
308 private double getFactorialLog(int n) {
309 return factorialLog.value(n);
310 }
311
312
313 @Override
314 public String toString() {
315 return "Large Mean Poisson deviate [" + rng.toString() + "]";
316 }
317
318
319
320
321
322
323 @Override
324 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
325 return new LargeMeanPoissonSampler(rng, this);
326 }
327
328
329
330
331
332
333
334
335
336
337
338 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
339 double mean) {
340 return new LargeMeanPoissonSampler(rng, mean);
341 }
342
343
344
345
346
347
348
349
350
351
352
353
354
355 LargeMeanPoissonSamplerState getState() {
356 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
357 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
358 }
359
360
361
362
363
364
365
366
367
368 static final class LargeMeanPoissonSamplerState {
369
370 private final double lambda;
371
372 private final double logLambda;
373
374 private final double logLambdaFactorial;
375
376 private final double delta;
377
378 private final double halfDelta;
379
380 private final double sqrtLambdaPlusHalfDelta;
381
382 private final double twolpd;
383
384 private final double p1;
385
386 private final double p2;
387
388 private final double c1;
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407 LargeMeanPoissonSamplerState(double lambda, double logLambda,
408 double logLambdaFactorial, double delta, double halfDelta,
409 double sqrtLambdaPlusHalfDelta, double twolpd,
410 double p1, double p2, double c1) {
411 this.lambda = lambda;
412 this.logLambda = logLambda;
413 this.logLambdaFactorial = logLambdaFactorial;
414 this.delta = delta;
415 this.halfDelta = halfDelta;
416 this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
417 this.twolpd = twolpd;
418 this.p1 = p1;
419 this.p2 = p2;
420 this.c1 = c1;
421 }
422
423
424
425
426
427
428
429 int getLambda() {
430 return (int) getLambdaRaw();
431 }
432
433
434
435
436 double getLambdaRaw() {
437 return lambda;
438 }
439
440
441
442
443 double getLogLambda() {
444 return logLambda;
445 }
446
447
448
449
450 double getLogLambdaFactorial() {
451 return logLambdaFactorial;
452 }
453
454
455
456
457 double getDelta() {
458 return delta;
459 }
460
461
462
463
464 double getHalfDelta() {
465 return halfDelta;
466 }
467
468
469
470
471 double getSqrtLambdaPlusHalfDelta() {
472 return sqrtLambdaPlusHalfDelta;
473 }
474
475
476
477
478 double getTwolpd() {
479 return twolpd;
480 }
481
482
483
484
485 double getP1() {
486 return p1;
487 }
488
489
490
491
492 double getP2() {
493 return p2;
494 }
495
496
497
498
499 double getC1() {
500 return c1;
501 }
502 }
503 }