001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog; 021 022/** 023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>. 024 * 025 * <ul> 026 * <li> 027 * For large means, we use the rejection algorithm described in 028 * <blockquote> 029 * Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br> 030 * <strong>Computing</strong> vol. 26 pp. 197-207. 031 * </blockquote> 032 * </li> 033 * </ul> 034 * 035 * <p>This sampler is suitable for {@code mean >= 40}.</p> 036 * 037 * <p>Sampling uses:</p> 038 * 039 * <ul> 040 * <li>{@link UniformRandomProvider#nextLong()} 041 * <li>{@link UniformRandomProvider#nextDouble()} 042 * </ul> 043 * 044 * @since 1.1 045 */ 046public class LargeMeanPoissonSampler 047 implements SharedStateDiscreteSampler { 048 /** Upper bound to avoid truncation. */ 049 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE; 050 /** Class to compute {@code log(n!)}. This has no cached values. */ 051 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG; 052 /** Used when there is no requirement for a small mean Poisson sampler. */ 053 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = 054 new SharedStateDiscreteSampler() { 055 @Override 056 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 057 // No requirement for RNG 058 return this; 059 } 060 061 @Override 062 public int sample() { 063 // No Poisson sample 064 return 0; 065 } 066 }; 067 068 static { 069 // Create without a cache. 070 NO_CACHE_FACTORIAL_LOG = FactorialLog.create(); 071 } 072 073 /** Underlying source of randomness. */ 074 private final UniformRandomProvider rng; 075 /** Exponential. */ 076 private final SharedStateContinuousSampler exponential; 077 /** Gaussian. */ 078 private final SharedStateContinuousSampler gaussian; 079 /** Local class to compute {@code log(n!)}. This may have cached values. */ 080 private final InternalUtils.FactorialLog factorialLog; 081 082 // Working values 083 084 /** Algorithm constant: {@code Math.floor(mean)}. */ 085 private final double lambda; 086 /** Algorithm constant: {@code Math.log(lambda)}. */ 087 private final double logLambda; 088 /** Algorithm constant: {@code factorialLog((int) lambda)}. */ 089 private final double logLambdaFactorial; 090 /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */ 091 private final double delta; 092 /** Algorithm constant: {@code delta / 2}. */ 093 private final double halfDelta; 094 /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */ 095 private final double sqrtLambdaPlusHalfDelta; 096 /** Algorithm constant: {@code 2 * lambda + delta}. */ 097 private final double twolpd; 098 /** 099 * Algorithm constant: {@code a1 / aSum}. 100 * <ul> 101 * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li> 102 * <li>{@code aSum = a1 + a2 + 1}</li> 103 * </ul> 104 */ 105 private final double p1; 106 /** 107 * Algorithm constant: {@code a2 / aSum}. 108 * <ul> 109 * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> 110 * <li>{@code aSum = a1 + a2 + 1}</li> 111 * </ul> 112 */ 113 private final double p2; 114 /** Algorithm constant: {@code 1 / (8 * lambda)}. */ 115 private final double c1; 116 117 /** The internal Poisson sampler for the lambda fraction. */ 118 private final SharedStateDiscreteSampler smallMeanPoissonSampler; 119 120 /** 121 * @param rng Generator of uniformly distributed random numbers. 122 * @param mean Mean. 123 * @throws IllegalArgumentException if {@code mean < 1} or 124 * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}. 125 */ 126 public LargeMeanPoissonSampler(UniformRandomProvider rng, 127 double mean) { 128 if (mean < 1) { 129 throw new IllegalArgumentException("mean is not >= 1: " + mean); 130 } 131 // The algorithm is not valid if Math.floor(mean) is not an integer. 132 if (mean > MAX_MEAN) { 133 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); 134 } 135 this.rng = rng; 136 137 gaussian = ZigguratSampler.NormalizedGaussian.of(rng); 138 exponential = ZigguratSampler.Exponential.of(rng); 139 // Plain constructor uses the uncached function. 140 factorialLog = NO_CACHE_FACTORIAL_LOG; 141 142 // Cache values used in the algorithm 143 lambda = Math.floor(mean); 144 logLambda = Math.log(lambda); 145 logLambdaFactorial = getFactorialLog((int) lambda); 146 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); 147 halfDelta = delta / 2; 148 sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta); 149 twolpd = 2 * lambda + delta; 150 c1 = 1 / (8 * lambda); 151 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1); 152 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd); 153 final double aSum = a1 + a2 + 1; 154 p1 = a1 / aSum; 155 p2 = a2 / aSum; 156 157 // The algorithm requires a Poisson sample from the remaining lambda fraction. 158 final double lambdaFractional = mean - lambda; 159 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 160 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 161 KempSmallMeanPoissonSampler.of(rng, lambdaFractional); 162 } 163 164 /** 165 * Instantiates a sampler using a precomputed state. 166 * 167 * @param rng Generator of uniformly distributed random numbers. 168 * @param state The state for {@code lambda = (int)Math.floor(mean)}. 169 * @param lambdaFractional The lambda fractional value 170 * ({@code mean - (int)Math.floor(mean))}. 171 * @throws IllegalArgumentException 172 * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. 173 */ 174 LargeMeanPoissonSampler(UniformRandomProvider rng, 175 LargeMeanPoissonSamplerState state, 176 double lambdaFractional) { 177 if (lambdaFractional < 0 || lambdaFractional >= 1) { 178 throw new IllegalArgumentException( 179 "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional); 180 } 181 this.rng = rng; 182 183 gaussian = ZigguratSampler.NormalizedGaussian.of(rng); 184 exponential = ZigguratSampler.Exponential.of(rng); 185 // Plain constructor uses the uncached function. 186 factorialLog = NO_CACHE_FACTORIAL_LOG; 187 188 // Use the state to initialize the algorithm 189 lambda = state.getLambdaRaw(); 190 logLambda = state.getLogLambda(); 191 logLambdaFactorial = state.getLogLambdaFactorial(); 192 delta = state.getDelta(); 193 halfDelta = state.getHalfDelta(); 194 sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta(); 195 twolpd = state.getTwolpd(); 196 p1 = state.getP1(); 197 p2 = state.getP2(); 198 c1 = state.getC1(); 199 200 // The algorithm requires a Poisson sample from the remaining lambda fraction. 201 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 202 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 203 KempSmallMeanPoissonSampler.of(rng, lambdaFractional); 204 } 205 206 /** 207 * @param rng Generator of uniformly distributed random numbers. 208 * @param source Source to copy. 209 */ 210 private LargeMeanPoissonSampler(UniformRandomProvider rng, 211 LargeMeanPoissonSampler source) { 212 this.rng = rng; 213 214 gaussian = source.gaussian.withUniformRandomProvider(rng); 215 exponential = source.exponential.withUniformRandomProvider(rng); 216 // Reuse the cache 217 factorialLog = source.factorialLog; 218 219 lambda = source.lambda; 220 logLambda = source.logLambda; 221 logLambdaFactorial = source.logLambdaFactorial; 222 delta = source.delta; 223 halfDelta = source.halfDelta; 224 sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta; 225 twolpd = source.twolpd; 226 p1 = source.p1; 227 p2 = source.p2; 228 c1 = source.c1; 229 230 // Share the state of the small sampler 231 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng); 232 } 233 234 /** {@inheritDoc} */ 235 @Override 236 public int sample() { 237 // This will never be null. It may be a no-op delegate that returns zero. 238 final int y2 = smallMeanPoissonSampler.sample(); 239 240 double x; 241 double y; 242 double v; 243 int a; 244 double t; 245 double qr; 246 double qa; 247 while (true) { 248 // Step 1: 249 final double u = rng.nextDouble(); 250 if (u <= p1) { 251 // Step 2: 252 final double n = gaussian.sample(); 253 x = n * sqrtLambdaPlusHalfDelta - 0.5d; 254 if (x > delta || x < -lambda) { 255 continue; 256 } 257 y = x < 0 ? Math.floor(x) : Math.ceil(x); 258 final double e = exponential.sample(); 259 v = -e - 0.5 * n * n + c1; 260 } else { 261 // Step 3: 262 if (u > p1 + p2) { 263 y = lambda; 264 break; 265 } 266 x = delta + (twolpd / delta) * exponential.sample(); 267 y = Math.ceil(x); 268 v = -exponential.sample() - delta * (x + 1) / twolpd; 269 } 270 // The Squeeze Principle 271 // Step 4.1: 272 a = x < 0 ? 1 : 0; 273 t = y * (y + 1) / (2 * lambda); 274 // Step 4.2 275 if (v < -t && a == 0) { 276 y = lambda + y; 277 break; 278 } 279 // Step 4.3: 280 qr = t * ((2 * y + 1) / (6 * lambda) - 1); 281 qa = qr - (t * t) / (3 * (lambda + a * (y + 1))); 282 // Step 4.4: 283 if (v < qa) { 284 y = lambda + y; 285 break; 286 } 287 // Step 4.5: 288 if (v > qr) { 289 continue; 290 } 291 // Step 4.6: 292 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) { 293 y = lambda + y; 294 break; 295 } 296 } 297 298 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE); 299 } 300 301 /** 302 * Compute the natural logarithm of the factorial of {@code n}. 303 * 304 * @param n Argument. 305 * @return {@code log(n!)} 306 * @throws IllegalArgumentException if {@code n < 0}. 307 */ 308 private double getFactorialLog(int n) { 309 return factorialLog.value(n); 310 } 311 312 /** {@inheritDoc} */ 313 @Override 314 public String toString() { 315 return "Large Mean Poisson deviate [" + rng.toString() + "]"; 316 } 317 318 /** 319 * {@inheritDoc} 320 * 321 * @since 1.3 322 */ 323 @Override 324 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 325 return new LargeMeanPoissonSampler(rng, this); 326 } 327 328 /** 329 * Creates a new Poisson distribution sampler. 330 * 331 * @param rng Generator of uniformly distributed random numbers. 332 * @param mean Mean. 333 * @return the sampler 334 * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *} 335 * {@link Integer#MAX_VALUE}. 336 * @since 1.3 337 */ 338 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 339 double mean) { 340 return new LargeMeanPoissonSampler(rng, mean); 341 } 342 343 /** 344 * Gets the initialisation state of the sampler. 345 * 346 * <p>The state is computed using an integer {@code lambda} value of 347 * {@code lambda = (int)Math.floor(mean)}. 348 * 349 * <p>The state will be suitable for reconstructing a new sampler with a mean 350 * in the range {@code lambda <= mean < lambda+1} using 351 * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}. 352 * 353 * @return the state 354 */ 355 LargeMeanPoissonSamplerState getState() { 356 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial, 357 delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1); 358 } 359 360 /** 361 * Encapsulate the state of the sampler. The state is valid for construction of 362 * a sampler in the range {@code lambda <= mean < lambda+1}. 363 * 364 * <p>This class is immutable. 365 * 366 * @see #getLambda() 367 */ 368 static final class LargeMeanPoissonSamplerState { 369 /** Algorithm constant {@code lambda}. */ 370 private final double lambda; 371 /** Algorithm constant {@code logLambda}. */ 372 private final double logLambda; 373 /** Algorithm constant {@code logLambdaFactorial}. */ 374 private final double logLambdaFactorial; 375 /** Algorithm constant {@code delta}. */ 376 private final double delta; 377 /** Algorithm constant {@code halfDelta}. */ 378 private final double halfDelta; 379 /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */ 380 private final double sqrtLambdaPlusHalfDelta; 381 /** Algorithm constant {@code twolpd}. */ 382 private final double twolpd; 383 /** Algorithm constant {@code p1}. */ 384 private final double p1; 385 /** Algorithm constant {@code p2}. */ 386 private final double p2; 387 /** Algorithm constant {@code c1}. */ 388 private final double c1; 389 390 /** 391 * Creates the state. 392 * 393 * <p>The state is valid for construction of a sampler in the range 394 * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer. 395 * 396 * @param lambda the lambda 397 * @param logLambda the log lambda 398 * @param logLambdaFactorial the log lambda factorial 399 * @param delta the delta 400 * @param halfDelta the half delta 401 * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta) 402 * @param twolpd the two lambda plus delta 403 * @param p1 the p1 constant 404 * @param p2 the p2 constant 405 * @param c1 the c1 constant 406 */ 407 LargeMeanPoissonSamplerState(double lambda, double logLambda, 408 double logLambdaFactorial, double delta, double halfDelta, 409 double sqrtLambdaPlusHalfDelta, double twolpd, 410 double p1, double p2, double c1) { 411 this.lambda = lambda; 412 this.logLambda = logLambda; 413 this.logLambdaFactorial = logLambdaFactorial; 414 this.delta = delta; 415 this.halfDelta = halfDelta; 416 this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta; 417 this.twolpd = twolpd; 418 this.p1 = p1; 419 this.p2 = p2; 420 this.c1 = c1; 421 } 422 423 /** 424 * Get the lambda value for the state. 425 * 426 * <p>Equal to {@code floor(mean)} for a Poisson sampler. 427 * @return the lambda value 428 */ 429 int getLambda() { 430 return (int) getLambdaRaw(); 431 } 432 433 /** 434 * @return algorithm constant {@code lambda} 435 */ 436 double getLambdaRaw() { 437 return lambda; 438 } 439 440 /** 441 * @return algorithm constant {@code logLambda} 442 */ 443 double getLogLambda() { 444 return logLambda; 445 } 446 447 /** 448 * @return algorithm constant {@code logLambdaFactorial} 449 */ 450 double getLogLambdaFactorial() { 451 return logLambdaFactorial; 452 } 453 454 /** 455 * @return algorithm constant {@code delta} 456 */ 457 double getDelta() { 458 return delta; 459 } 460 461 /** 462 * @return algorithm constant {@code halfDelta} 463 */ 464 double getHalfDelta() { 465 return halfDelta; 466 } 467 468 /** 469 * @return algorithm constant {@code sqrtLambdaPlusHalfDelta} 470 */ 471 double getSqrtLambdaPlusHalfDelta() { 472 return sqrtLambdaPlusHalfDelta; 473 } 474 475 /** 476 * @return algorithm constant {@code twolpd} 477 */ 478 double getTwolpd() { 479 return twolpd; 480 } 481 482 /** 483 * @return algorithm constant {@code p1} 484 */ 485 double getP1() { 486 return p1; 487 } 488 489 /** 490 * @return algorithm constant {@code p2} 491 */ 492 double getP2() { 493 return p2; 494 } 495 496 /** 497 * @return algorithm constant {@code c1} 498 */ 499 double getC1() { 500 return c1; 501 } 502 } 503}