001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog; 021 022/** 023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>. 024 * 025 * <ul> 026 * <li> 027 * For large means, we use the rejection algorithm described in 028 * <blockquote> 029 * Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br> 030 * <strong>Computing</strong> vol. 26 pp. 197-207. 031 * </blockquote> 032 * </li> 033 * </ul> 034 * 035 * <p>This sampler is suitable for {@code mean >= 40}.</p> 036 * 037 * <p>Sampling uses:</p> 038 * 039 * <ul> 040 * <li>{@link UniformRandomProvider#nextLong()} 041 * <li>{@link UniformRandomProvider#nextDouble()} 042 * </ul> 043 * 044 * @since 1.1 045 */ 046public class LargeMeanPoissonSampler 047 implements SharedStateDiscreteSampler { 048 /** Upper bound to avoid truncation. */ 049 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE; 050 /** Class to compute {@code log(n!)}. This has no cached values. */ 051 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG; 052 /** Used when there is no requirement for a small mean Poisson sampler. */ 053 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = 054 new SharedStateDiscreteSampler() { 055 @Override 056 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 057 // No requirement for RNG 058 return this; 059 } 060 061 @Override 062 public int sample() { 063 // No Poisson sample 064 return 0; 065 } 066 }; 067 068 static { 069 // Create without a cache. 070 NO_CACHE_FACTORIAL_LOG = FactorialLog.create(); 071 } 072 073 /** Underlying source of randomness. */ 074 private final UniformRandomProvider rng; 075 /** Exponential. */ 076 private final SharedStateContinuousSampler exponential; 077 /** Gaussian. */ 078 private final SharedStateContinuousSampler gaussian; 079 /** Local class to compute {@code log(n!)}. This may have cached values. */ 080 private final InternalUtils.FactorialLog factorialLog; 081 082 // Working values 083 084 /** Algorithm constant: {@code Math.floor(mean)}. */ 085 private final double lambda; 086 /** Algorithm constant: {@code Math.log(lambda)}. */ 087 private final double logLambda; 088 /** Algorithm constant: {@code factorialLog((int) lambda)}. */ 089 private final double logLambdaFactorial; 090 /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */ 091 private final double delta; 092 /** Algorithm constant: {@code delta / 2}. */ 093 private final double halfDelta; 094 /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */ 095 private final double sqrtLambdaPlusHalfDelta; 096 /** Algorithm constant: {@code 2 * lambda + delta}. */ 097 private final double twolpd; 098 /** 099 * Algorithm constant: {@code a1 / aSum}. 100 * <ul> 101 * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li> 102 * <li>{@code aSum = a1 + a2 + 1}</li> 103 * </ul> 104 */ 105 private final double p1; 106 /** 107 * Algorithm constant: {@code a2 / aSum}. 108 * <ul> 109 * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> 110 * <li>{@code aSum = a1 + a2 + 1}</li> 111 * </ul> 112 */ 113 private final double p2; 114 /** Algorithm constant: {@code 1 / (8 * lambda)}. */ 115 private final double c1; 116 117 /** The internal Poisson sampler for the lambda fraction. */ 118 private final SharedStateDiscreteSampler smallMeanPoissonSampler; 119 120 121 /** 122 * Create an instance. 123 * 124 * @param rng Generator of uniformly distributed random numbers. 125 * @param mean Mean. 126 * @throws IllegalArgumentException if {@code mean < 1} or 127 * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}. 128 */ 129 public LargeMeanPoissonSampler(UniformRandomProvider rng, 130 double mean) { 131 // Validation before java.lang.Object constructor exits prevents partially initialized object 132 this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng); 133 } 134 135 /** 136 * Instantiates a sampler using a precomputed state. 137 * 138 * @param rng Generator of uniformly distributed random numbers. 139 * @param state The state for {@code lambda = (int)Math.floor(mean)}. 140 * @param lambdaFractional The lambda fractional value 141 * ({@code mean - (int)Math.floor(mean))}. 142 * @throws IllegalArgumentException 143 * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. 144 */ 145 LargeMeanPoissonSampler(UniformRandomProvider rng, 146 LargeMeanPoissonSamplerState state, 147 double lambdaFractional) { 148 // Validation before java.lang.Object constructor exits prevents partially initialized object 149 this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng); 150 } 151 152 /** 153 * @param mean Mean. 154 * @param rng Generator of uniformly distributed random numbers. 155 */ 156 private LargeMeanPoissonSampler(double mean, 157 UniformRandomProvider rng) { 158 this.rng = rng; 159 160 gaussian = ZigguratSampler.NormalizedGaussian.of(rng); 161 exponential = ZigguratSampler.Exponential.of(rng); 162 // Plain constructor uses the uncached function. 163 factorialLog = NO_CACHE_FACTORIAL_LOG; 164 165 // Cache values used in the algorithm 166 lambda = Math.floor(mean); 167 logLambda = Math.log(lambda); 168 logLambdaFactorial = getFactorialLog((int) lambda); 169 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); 170 halfDelta = delta / 2; 171 sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta); 172 twolpd = 2 * lambda + delta; 173 c1 = 1 / (8 * lambda); 174 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1); 175 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd); 176 final double aSum = a1 + a2 + 1; 177 p1 = a1 / aSum; 178 p2 = a2 / aSum; 179 180 // The algorithm requires a Poisson sample from the remaining lambda fraction. 181 final double lambdaFractional = mean - lambda; 182 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 183 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 184 KempSmallMeanPoissonSampler.of(rng, lambdaFractional); 185 } 186 187 /** 188 * Instantiates a sampler using a precomputed state. 189 * 190 * @param state The state for {@code lambda = (int)Math.floor(mean)}. 191 * @param lambdaFractional The lambda fractional value 192 * ({@code mean - (int)Math.floor(mean))}. 193 * @param rng Generator of uniformly distributed random numbers. 194 */ 195 private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state, 196 double lambdaFractional, 197 UniformRandomProvider rng) { 198 this.rng = rng; 199 200 gaussian = ZigguratSampler.NormalizedGaussian.of(rng); 201 exponential = ZigguratSampler.Exponential.of(rng); 202 // Plain constructor uses the uncached function. 203 factorialLog = NO_CACHE_FACTORIAL_LOG; 204 205 // Use the state to initialize the algorithm 206 lambda = state.getLambdaRaw(); 207 logLambda = state.getLogLambda(); 208 logLambdaFactorial = state.getLogLambdaFactorial(); 209 delta = state.getDelta(); 210 halfDelta = state.getHalfDelta(); 211 sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta(); 212 twolpd = state.getTwolpd(); 213 p1 = state.getP1(); 214 p2 = state.getP2(); 215 c1 = state.getC1(); 216 217 // The algorithm requires a Poisson sample from the remaining lambda fraction. 218 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 219 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 220 KempSmallMeanPoissonSampler.of(rng, lambdaFractional); 221 } 222 223 /** 224 * @param rng Generator of uniformly distributed random numbers. 225 * @param source Source to copy. 226 */ 227 private LargeMeanPoissonSampler(UniformRandomProvider rng, 228 LargeMeanPoissonSampler source) { 229 this.rng = rng; 230 231 gaussian = source.gaussian.withUniformRandomProvider(rng); 232 exponential = source.exponential.withUniformRandomProvider(rng); 233 // Reuse the cache 234 factorialLog = source.factorialLog; 235 236 lambda = source.lambda; 237 logLambda = source.logLambda; 238 logLambdaFactorial = source.logLambdaFactorial; 239 delta = source.delta; 240 halfDelta = source.halfDelta; 241 sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta; 242 twolpd = source.twolpd; 243 p1 = source.p1; 244 p2 = source.p2; 245 c1 = source.c1; 246 247 // Share the state of the small sampler 248 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng); 249 } 250 251 /** {@inheritDoc} */ 252 @Override 253 public int sample() { 254 // This will never be null. It may be a no-op delegate that returns zero. 255 final int y2 = smallMeanPoissonSampler.sample(); 256 257 double x; 258 double y; 259 double v; 260 int a; 261 double t; 262 double qr; 263 double qa; 264 while (true) { 265 // Step 1: 266 final double u = rng.nextDouble(); 267 if (u <= p1) { 268 // Step 2: 269 final double n = gaussian.sample(); 270 x = n * sqrtLambdaPlusHalfDelta - 0.5d; 271 if (x > delta || x < -lambda) { 272 continue; 273 } 274 y = x < 0 ? Math.floor(x) : Math.ceil(x); 275 final double e = exponential.sample(); 276 v = -e - 0.5 * n * n + c1; 277 } else { 278 // Step 3: 279 if (u > p1 + p2) { 280 y = lambda; 281 break; 282 } 283 x = delta + (twolpd / delta) * exponential.sample(); 284 y = Math.ceil(x); 285 v = -exponential.sample() - delta * (x + 1) / twolpd; 286 } 287 // The Squeeze Principle 288 // Step 4.1: 289 a = x < 0 ? 1 : 0; 290 t = y * (y + 1) / (2 * lambda); 291 // Step 4.2 292 if (v < -t && a == 0) { 293 y = lambda + y; 294 break; 295 } 296 // Step 4.3: 297 qr = t * ((2 * y + 1) / (6 * lambda) - 1); 298 qa = qr - (t * t) / (3 * (lambda + a * (y + 1))); 299 // Step 4.4: 300 if (v < qa) { 301 y = lambda + y; 302 break; 303 } 304 // Step 4.5: 305 if (v > qr) { 306 continue; 307 } 308 // Step 4.6: 309 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) { 310 y = lambda + y; 311 break; 312 } 313 } 314 315 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE); 316 } 317 318 /** 319 * Compute the natural logarithm of the factorial of {@code n}. 320 * 321 * @param n Argument. 322 * @return {@code log(n!)} 323 * @throws IllegalArgumentException if {@code n < 0}. 324 */ 325 private double getFactorialLog(int n) { 326 return factorialLog.value(n); 327 } 328 329 /** {@inheritDoc} */ 330 @Override 331 public String toString() { 332 return "Large Mean Poisson deviate [" + rng.toString() + "]"; 333 } 334 335 /** 336 * {@inheritDoc} 337 * 338 * @since 1.3 339 */ 340 @Override 341 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 342 return new LargeMeanPoissonSampler(rng, this); 343 } 344 345 /** 346 * Creates a new Poisson distribution sampler. 347 * 348 * @param rng Generator of uniformly distributed random numbers. 349 * @param mean Mean. 350 * @return the sampler 351 * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *} 352 * {@link Integer#MAX_VALUE}. 353 * @since 1.3 354 */ 355 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 356 double mean) { 357 return new LargeMeanPoissonSampler(rng, mean); 358 } 359 360 /** 361 * Gets the initialisation state of the sampler. 362 * 363 * <p>The state is computed using an integer {@code lambda} value of 364 * {@code lambda = (int)Math.floor(mean)}. 365 * 366 * <p>The state will be suitable for reconstructing a new sampler with a mean 367 * in the range {@code lambda <= mean < lambda+1} using 368 * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}. 369 * 370 * @return the state 371 */ 372 LargeMeanPoissonSamplerState getState() { 373 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial, 374 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1); 375 } 376 377 /** 378 * Encapsulate the state of the sampler. The state is valid for construction of 379 * a sampler in the range {@code lambda <= mean < lambda+1}. 380 * 381 * <p>This class is immutable. 382 * 383 * @see #getLambda() 384 */ 385 static final class LargeMeanPoissonSamplerState { 386 /** Algorithm constant {@code lambda}. */ 387 private final double lambda; 388 /** Algorithm constant {@code logLambda}. */ 389 private final double logLambda; 390 /** Algorithm constant {@code logLambdaFactorial}. */ 391 private final double logLambdaFactorial; 392 /** Algorithm constant {@code delta}. */ 393 private final double delta; 394 /** Algorithm constant {@code halfDelta}. */ 395 private final double halfDelta; 396 /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */ 397 private final double sqrtLambdaPlusHalfDelta; 398 /** Algorithm constant {@code twolpd}. */ 399 private final double twolpd; 400 /** Algorithm constant {@code p1}. */ 401 private final double p1; 402 /** Algorithm constant {@code p2}. */ 403 private final double p2; 404 /** Algorithm constant {@code c1}. */ 405 private final double c1; 406 407 /** 408 * Creates the state. 409 * 410 * <p>The state is valid for construction of a sampler in the range 411 * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer. 412 * 413 * @param lambda the lambda 414 * @param logLambda the log lambda 415 * @param logLambdaFactorial the log lambda factorial 416 * @param delta the delta 417 * @param halfDelta the half delta 418 * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta) 419 * @param twolpd the two lambda plus delta 420 * @param p1 the p1 constant 421 * @param p2 the p2 constant 422 * @param c1 the c1 constant 423 */ 424 LargeMeanPoissonSamplerState(double lambda, double logLambda, 425 double logLambdaFactorial, double delta, double halfDelta, 426 double sqrtLambdaPlusHalfDelta, double twolpd, 427 double p1, double p2, double c1) { 428 this.lambda = lambda; 429 this.logLambda = logLambda; 430 this.logLambdaFactorial = logLambdaFactorial; 431 this.delta = delta; 432 this.halfDelta = halfDelta; 433 this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta; 434 this.twolpd = twolpd; 435 this.p1 = p1; 436 this.p2 = p2; 437 this.c1 = c1; 438 } 439 440 /** 441 * Get the lambda value for the state. 442 * 443 * <p>Equal to {@code floor(mean)} for a Poisson sampler. 444 * @return the lambda value 445 */ 446 int getLambda() { 447 return (int) getLambdaRaw(); 448 } 449 450 /** 451 * @return algorithm constant {@code lambda} 452 */ 453 double getLambdaRaw() { 454 return lambda; 455 } 456 457 /** 458 * @return algorithm constant {@code logLambda} 459 */ 460 double getLogLambda() { 461 return logLambda; 462 } 463 464 /** 465 * @return algorithm constant {@code logLambdaFactorial} 466 */ 467 double getLogLambdaFactorial() { 468 return logLambdaFactorial; 469 } 470 471 /** 472 * @return algorithm constant {@code delta} 473 */ 474 double getDelta() { 475 return delta; 476 } 477 478 /** 479 * @return algorithm constant {@code halfDelta} 480 */ 481 double getHalfDelta() { 482 return halfDelta; 483 } 484 485 /** 486 * @return algorithm constant {@code sqrtLambdaPlusHalfDelta} 487 */ 488 double getSqrtLambdaPlusHalfDelta() { 489 return sqrtLambdaPlusHalfDelta; 490 } 491 492 /** 493 * @return algorithm constant {@code twolpd} 494 */ 495 double getTwolpd() { 496 return twolpd; 497 } 498 499 /** 500 * @return algorithm constant {@code p1} 501 */ 502 double getP1() { 503 return p1; 504 } 505 506 /** 507 * @return algorithm constant {@code p2} 508 */ 509 double getP2() { 510 return p2; 511 } 512 513 /** 514 * @return algorithm constant {@code c1} 515 */ 516 double getC1() { 517 return c1; 518 } 519 } 520}