001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog; 021 022/** 023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>. 024 * 025 * <ul> 026 * <li> 027 * For large means, we use the rejection algorithm described in 028 * <blockquote> 029 * Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br> 030 * <strong>Computing</strong> vol. 26 pp. 197-207. 031 * </blockquote> 032 * </li> 033 * </ul> 034 * 035 * <p>This sampler is suitable for {@code mean >= 40}.</p> 036 * 037 * <p>Sampling uses:</p> 038 * 039 * <ul> 040 * <li>{@link UniformRandomProvider#nextLong()} 041 * <li>{@link UniformRandomProvider#nextDouble()} 042 * </ul> 043 * 044 * @since 1.1 045 */ 046public class LargeMeanPoissonSampler 047 implements SharedStateDiscreteSampler { 048 /** Upper bound to avoid truncation. */ 049 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE; 050 /** Class to compute {@code log(n!)}. This has no cached values. */ 051 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG; 052 /** Used when there is no requirement for a small mean Poisson sampler. */ 053 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = 054 new SharedStateDiscreteSampler() { 055 @Override 056 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 057 // No requirement for RNG 058 return this; 059 } 060 061 @Override 062 public int sample() { 063 // No Poisson sample 064 return 0; 065 } 066 }; 067 068 static { 069 // Create without a cache. 070 NO_CACHE_FACTORIAL_LOG = FactorialLog.create(); 071 } 072 073 /** Underlying source of randomness. */ 074 private final UniformRandomProvider rng; 075 /** Exponential. */ 076 private final SharedStateContinuousSampler exponential; 077 /** Gaussian. */ 078 private final SharedStateContinuousSampler gaussian; 079 /** Local class to compute {@code log(n!)}. This may have cached values. */ 080 private final InternalUtils.FactorialLog factorialLog; 081 082 // Working values 083 084 /** Algorithm constant: {@code Math.floor(mean)}. */ 085 private final double lambda; 086 /** Algorithm constant: {@code Math.log(lambda)}. */ 087 private final double logLambda; 088 /** Algorithm constant: {@code factorialLog((int) lambda)}. */ 089 private final double logLambdaFactorial; 090 /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */ 091 private final double delta; 092 /** Algorithm constant: {@code delta / 2}. */ 093 private final double halfDelta; 094 /** Algorithm constant: {@code 2 * lambda + delta}. */ 095 private final double twolpd; 096 /** 097 * Algorithm constant: {@code a1 / aSum}. 098 * <ul> 099 * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li> 100 * <li>{@code aSum = a1 + a2 + 1}</li> 101 * </ul> 102 */ 103 private final double p1; 104 /** 105 * Algorithm constant: {@code a2 / aSum}. 106 * <ul> 107 * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> 108 * <li>{@code aSum = a1 + a2 + 1}</li> 109 * </ul> 110 */ 111 private final double p2; 112 /** Algorithm constant: {@code 1 / (8 * lambda)}. */ 113 private final double c1; 114 115 /** The internal Poisson sampler for the lambda fraction. */ 116 private final SharedStateDiscreteSampler smallMeanPoissonSampler; 117 118 /** 119 * @param rng Generator of uniformly distributed random numbers. 120 * @param mean Mean. 121 * @throws IllegalArgumentException if {@code mean < 1} or 122 * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}. 123 */ 124 public LargeMeanPoissonSampler(UniformRandomProvider rng, 125 double mean) { 126 if (mean < 1) { 127 throw new IllegalArgumentException("mean is not >= 1: " + mean); 128 } 129 // The algorithm is not valid if Math.floor(mean) is not an integer. 130 if (mean > MAX_MEAN) { 131 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); 132 } 133 this.rng = rng; 134 135 gaussian = ZigguratSampler.NormalizedGaussian.of(rng); 136 exponential = ZigguratSampler.Exponential.of(rng); 137 // Plain constructor uses the uncached function. 138 factorialLog = NO_CACHE_FACTORIAL_LOG; 139 140 // Cache values used in the algorithm 141 lambda = Math.floor(mean); 142 logLambda = Math.log(lambda); 143 logLambdaFactorial = getFactorialLog((int) lambda); 144 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); 145 halfDelta = delta / 2; 146 twolpd = 2 * lambda + delta; 147 c1 = 1 / (8 * lambda); 148 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1); 149 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd); 150 final double aSum = a1 + a2 + 1; 151 p1 = a1 / aSum; 152 p2 = a2 / aSum; 153 154 // The algorithm requires a Poisson sample from the remaining lambda fraction. 155 final double lambdaFractional = mean - lambda; 156 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 157 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 158 KempSmallMeanPoissonSampler.of(rng, lambdaFractional); 159 } 160 161 /** 162 * Instantiates a sampler using a precomputed state. 163 * 164 * @param rng Generator of uniformly distributed random numbers. 165 * @param state The state for {@code lambda = (int)Math.floor(mean)}. 166 * @param lambdaFractional The lambda fractional value 167 * ({@code mean - (int)Math.floor(mean))}. 168 * @throws IllegalArgumentException 169 * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. 170 */ 171 LargeMeanPoissonSampler(UniformRandomProvider rng, 172 LargeMeanPoissonSamplerState state, 173 double lambdaFractional) { 174 if (lambdaFractional < 0 || lambdaFractional >= 1) { 175 throw new IllegalArgumentException( 176 "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional); 177 } 178 this.rng = rng; 179 180 gaussian = ZigguratSampler.NormalizedGaussian.of(rng); 181 exponential = ZigguratSampler.Exponential.of(rng); 182 // Plain constructor uses the uncached function. 183 factorialLog = NO_CACHE_FACTORIAL_LOG; 184 185 // Use the state to initialize the algorithm 186 lambda = state.getLambdaRaw(); 187 logLambda = state.getLogLambda(); 188 logLambdaFactorial = state.getLogLambdaFactorial(); 189 delta = state.getDelta(); 190 halfDelta = state.getHalfDelta(); 191 twolpd = state.getTwolpd(); 192 p1 = state.getP1(); 193 p2 = state.getP2(); 194 c1 = state.getC1(); 195 196 // The algorithm requires a Poisson sample from the remaining lambda fraction. 197 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 198 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 199 KempSmallMeanPoissonSampler.of(rng, lambdaFractional); 200 } 201 202 /** 203 * @param rng Generator of uniformly distributed random numbers. 204 * @param source Source to copy. 205 */ 206 private LargeMeanPoissonSampler(UniformRandomProvider rng, 207 LargeMeanPoissonSampler source) { 208 this.rng = rng; 209 210 gaussian = source.gaussian.withUniformRandomProvider(rng); 211 exponential = source.exponential.withUniformRandomProvider(rng); 212 // Reuse the cache 213 factorialLog = source.factorialLog; 214 215 lambda = source.lambda; 216 logLambda = source.logLambda; 217 logLambdaFactorial = source.logLambdaFactorial; 218 delta = source.delta; 219 halfDelta = source.halfDelta; 220 twolpd = source.twolpd; 221 p1 = source.p1; 222 p2 = source.p2; 223 c1 = source.c1; 224 225 // Share the state of the small sampler 226 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng); 227 } 228 229 /** {@inheritDoc} */ 230 @Override 231 public int sample() { 232 // This will never be null. It may be a no-op delegate that returns zero. 233 final int y2 = smallMeanPoissonSampler.sample(); 234 235 double x; 236 double y; 237 double v; 238 int a; 239 double t; 240 double qr; 241 double qa; 242 while (true) { 243 // Step 1: 244 final double u = rng.nextDouble(); 245 if (u <= p1) { 246 // Step 2: 247 final double n = gaussian.sample(); 248 x = n * Math.sqrt(lambda + halfDelta) - 0.5d; 249 if (x > delta || x < -lambda) { 250 continue; 251 } 252 y = x < 0 ? Math.floor(x) : Math.ceil(x); 253 final double e = exponential.sample(); 254 v = -e - 0.5 * n * n + c1; 255 } else { 256 // Step 3: 257 if (u > p1 + p2) { 258 y = lambda; 259 break; 260 } 261 x = delta + (twolpd / delta) * exponential.sample(); 262 y = Math.ceil(x); 263 v = -exponential.sample() - delta * (x + 1) / twolpd; 264 } 265 // The Squeeze Principle 266 // Step 4.1: 267 a = x < 0 ? 1 : 0; 268 t = y * (y + 1) / (2 * lambda); 269 // Step 4.2 270 if (v < -t && a == 0) { 271 y = lambda + y; 272 break; 273 } 274 // Step 4.3: 275 qr = t * ((2 * y + 1) / (6 * lambda) - 1); 276 qa = qr - (t * t) / (3 * (lambda + a * (y + 1))); 277 // Step 4.4: 278 if (v < qa) { 279 y = lambda + y; 280 break; 281 } 282 // Step 4.5: 283 if (v > qr) { 284 continue; 285 } 286 // Step 4.6: 287 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) { 288 y = lambda + y; 289 break; 290 } 291 } 292 293 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE); 294 } 295 296 /** 297 * Compute the natural logarithm of the factorial of {@code n}. 298 * 299 * @param n Argument. 300 * @return {@code log(n!)} 301 * @throws IllegalArgumentException if {@code n < 0}. 302 */ 303 private double getFactorialLog(int n) { 304 return factorialLog.value(n); 305 } 306 307 /** {@inheritDoc} */ 308 @Override 309 public String toString() { 310 return "Large Mean Poisson deviate [" + rng.toString() + "]"; 311 } 312 313 /** 314 * {@inheritDoc} 315 * 316 * @since 1.3 317 */ 318 @Override 319 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 320 return new LargeMeanPoissonSampler(rng, this); 321 } 322 323 /** 324 * Creates a new Poisson distribution sampler. 325 * 326 * @param rng Generator of uniformly distributed random numbers. 327 * @param mean Mean. 328 * @return the sampler 329 * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *} 330 * {@link Integer#MAX_VALUE}. 331 * @since 1.3 332 */ 333 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 334 double mean) { 335 return new LargeMeanPoissonSampler(rng, mean); 336 } 337 /** 338 * Gets the initialisation state of the sampler. 339 * 340 * <p>The state is computed using an integer {@code lambda} value of 341 * {@code lambda = (int)Math.floor(mean)}. 342 * 343 * <p>The state will be suitable for reconstructing a new sampler with a mean 344 * in the range {@code lambda <= mean < lambda+1} using 345 * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}. 346 * 347 * @return the state 348 */ 349 LargeMeanPoissonSamplerState getState() { 350 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial, 351 delta, halfDelta, twolpd, p1, p2, c1); 352 } 353 354 /** 355 * Encapsulate the state of the sampler. The state is valid for construction of 356 * a sampler in the range {@code lambda <= mean < lambda+1}. 357 * 358 * <p>This class is immutable. 359 * 360 * @see #getLambda() 361 */ 362 static final class LargeMeanPoissonSamplerState { 363 /** Algorithm constant {@code lambda}. */ 364 private final double lambda; 365 /** Algorithm constant {@code logLambda}. */ 366 private final double logLambda; 367 /** Algorithm constant {@code logLambdaFactorial}. */ 368 private final double logLambdaFactorial; 369 /** Algorithm constant {@code delta}. */ 370 private final double delta; 371 /** Algorithm constant {@code halfDelta}. */ 372 private final double halfDelta; 373 /** Algorithm constant {@code twolpd}. */ 374 private final double twolpd; 375 /** Algorithm constant {@code p1}. */ 376 private final double p1; 377 /** Algorithm constant {@code p2}. */ 378 private final double p2; 379 /** Algorithm constant {@code c1}. */ 380 private final double c1; 381 382 /** 383 * Creates the state. 384 * 385 * <p>The state is valid for construction of a sampler in the range 386 * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer. 387 * 388 * @param lambda the lambda 389 * @param logLambda the log lambda 390 * @param logLambdaFactorial the log lambda factorial 391 * @param delta the delta 392 * @param halfDelta the half delta 393 * @param twolpd the two lambda plus delta 394 * @param p1 the p1 constant 395 * @param p2 the p2 constant 396 * @param c1 the c1 constant 397 */ 398 LargeMeanPoissonSamplerState(double lambda, double logLambda, 399 double logLambdaFactorial, double delta, double halfDelta, double twolpd, 400 double p1, double p2, double c1) { 401 this.lambda = lambda; 402 this.logLambda = logLambda; 403 this.logLambdaFactorial = logLambdaFactorial; 404 this.delta = delta; 405 this.halfDelta = halfDelta; 406 this.twolpd = twolpd; 407 this.p1 = p1; 408 this.p2 = p2; 409 this.c1 = c1; 410 } 411 412 /** 413 * Get the lambda value for the state. 414 * 415 * <p>Equal to {@code floor(mean)} for a Poisson sampler. 416 * @return the lambda value 417 */ 418 int getLambda() { 419 return (int) getLambdaRaw(); 420 } 421 422 /** 423 * @return algorithm constant {@code lambda} 424 */ 425 double getLambdaRaw() { 426 return lambda; 427 } 428 429 /** 430 * @return algorithm constant {@code logLambda} 431 */ 432 double getLogLambda() { 433 return logLambda; 434 } 435 436 /** 437 * @return algorithm constant {@code logLambdaFactorial} 438 */ 439 double getLogLambdaFactorial() { 440 return logLambdaFactorial; 441 } 442 443 /** 444 * @return algorithm constant {@code delta} 445 */ 446 double getDelta() { 447 return delta; 448 } 449 450 /** 451 * @return algorithm constant {@code halfDelta} 452 */ 453 double getHalfDelta() { 454 return halfDelta; 455 } 456 457 /** 458 * @return algorithm constant {@code twolpd} 459 */ 460 double getTwolpd() { 461 return twolpd; 462 } 463 464 /** 465 * @return algorithm constant {@code p1} 466 */ 467 double getP1() { 468 return p1; 469 } 470 471 /** 472 * @return algorithm constant {@code p2} 473 */ 474 double getP2() { 475 return p2; 476 } 477 478 /** 479 * @return algorithm constant {@code c1} 480 */ 481 double getC1() { 482 return c1; 483 } 484 } 485}