001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog; 021 022/** 023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>. 024 * 025 * <ul> 026 * <li> 027 * For large means, we use the rejection algorithm described in 028 * <blockquote> 029 * Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br> 030 * <strong>Computing</strong> vol. 26 pp. 197-207. 031 * </blockquote> 032 * </li> 033 * </ul> 034 * 035 * @since 1.1 036 * 037 * This sampler is suitable for {@code mean >= 40}. 038 */ 039public class LargeMeanPoissonSampler 040 implements DiscreteSampler { 041 /** Upper bound to avoid truncation. */ 042 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE; 043 /** Class to compute {@code log(n!)}. This has no cached values. */ 044 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG; 045 /** Used when there is no requirement for a small mean Poisson sampler. */ 046 private static final DiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = null; 047 048 static { 049 // Create without a cache. 050 NO_CACHE_FACTORIAL_LOG = FactorialLog.create(); 051 } 052 053 /** Underlying source of randomness. */ 054 private final UniformRandomProvider rng; 055 /** Exponential. */ 056 private final ContinuousSampler exponential; 057 /** Gaussian. */ 058 private final ContinuousSampler gaussian; 059 /** Local class to compute {@code log(n!)}. This may have cached values. */ 060 private final InternalUtils.FactorialLog factorialLog; 061 062 // Working values 063 064 /** Algorithm constant: {@code Math.floor(mean)}. */ 065 private final double lambda; 066 /** Algorithm constant: {@code Math.log(lambda)}. */ 067 private final double logLambda; 068 /** Algorithm constant: {@code factorialLog((int) lambda)}. */ 069 private final double logLambdaFactorial; 070 /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */ 071 private final double delta; 072 /** Algorithm constant: {@code delta / 2}. */ 073 private final double halfDelta; 074 /** Algorithm constant: {@code 2 * lambda + delta}. */ 075 private final double twolpd; 076 /** 077 * Algorithm constant: {@code a1 / aSum} with 078 * <ul> 079 * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li> 080 * <li>{@code aSum = a1 + a2 + 1}</li> 081 * </ul> 082 */ 083 private final double p1; 084 /** 085 * Algorithm constant: {@code a2 / aSum} with 086 * <ul> 087 * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> 088 * <li>{@code aSum = a1 + a2 + 1}</li> 089 * </ul> 090 */ 091 private final double p2; 092 /** Algorithm constant: {@code 1 / (8 * lambda)}. */ 093 private final double c1; 094 095 /** The internal Poisson sampler for the lambda fraction. */ 096 private final DiscreteSampler smallMeanPoissonSampler; 097 098 /** 099 * @param rng Generator of uniformly distributed random numbers. 100 * @param mean Mean. 101 * @throws IllegalArgumentException if {@code mean <= 0} or 102 * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}. 103 */ 104 public LargeMeanPoissonSampler(UniformRandomProvider rng, 105 double mean) { 106 if (mean <= 0) { 107 throw new IllegalArgumentException(mean + " <= " + 0); 108 } 109 // The algorithm is not valid if Math.floor(mean) is not an integer. 110 if (mean > MAX_MEAN) { 111 throw new IllegalArgumentException(mean + " > " + MAX_MEAN); 112 } 113 this.rng = rng; 114 115 gaussian = new ZigguratNormalizedGaussianSampler(rng); 116 exponential = new AhrensDieterExponentialSampler(rng, 1); 117 // Plain constructor uses the uncached function. 118 factorialLog = NO_CACHE_FACTORIAL_LOG; 119 120 // Cache values used in the algorithm 121 lambda = Math.floor(mean); 122 logLambda = Math.log(lambda); 123 logLambdaFactorial = factorialLog((int) lambda); 124 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); 125 halfDelta = delta / 2; 126 twolpd = 2 * lambda + delta; 127 c1 = 1 / (8 * lambda); 128 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1); 129 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd); 130 final double aSum = a1 + a2 + 1; 131 p1 = a1 / aSum; 132 p2 = a2 / aSum; 133 134 // The algorithm requires a Poisson sample from the remaining lambda fraction. 135 final double lambdaFractional = mean - lambda; 136 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 137 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 138 new SmallMeanPoissonSampler(rng, lambdaFractional); 139 } 140 141 /** 142 * Instantiates a sampler using a precomputed state. 143 * 144 * @param rng Generator of uniformly distributed random numbers. 145 * @param state The state for {@code lambda = (int)Math.floor(mean)}. 146 * @param lambdaFractional The lambda fractional value 147 * ({@code mean - (int)Math.floor(mean))}. 148 * @throws IllegalArgumentException 149 * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. 150 */ 151 LargeMeanPoissonSampler(UniformRandomProvider rng, 152 LargeMeanPoissonSamplerState state, 153 double lambdaFractional) { 154 if (lambdaFractional < 0 || lambdaFractional >= 1) { 155 throw new IllegalArgumentException( 156 "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional); 157 } 158 this.rng = rng; 159 160 gaussian = new ZigguratNormalizedGaussianSampler(rng); 161 exponential = new AhrensDieterExponentialSampler(rng, 1); 162 // Plain constructor uses the uncached function. 163 factorialLog = NO_CACHE_FACTORIAL_LOG; 164 165 // Use the state to initialise the algorithm 166 lambda = state.getLambdaRaw(); 167 logLambda = state.getLogLambda(); 168 logLambdaFactorial = state.getLogLambdaFactorial(); 169 delta = state.getDelta(); 170 halfDelta = state.getHalfDelta(); 171 twolpd = state.getTwolpd(); 172 p1 = state.getP1(); 173 p2 = state.getP2(); 174 c1 = state.getC1(); 175 176 // The algorithm requires a Poisson sample from the remaining lambda fraction. 177 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? 178 NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 179 new SmallMeanPoissonSampler(rng, lambdaFractional); 180 } 181 182 /** {@inheritDoc} */ 183 @Override 184 public int sample() { 185 186 final int y2 = (smallMeanPoissonSampler == null) ? 187 0 : // No lambda fraction 188 smallMeanPoissonSampler.sample(); 189 190 double x = 0; 191 double y = 0; 192 double v = 0; 193 int a = 0; 194 double t = 0; 195 double qr = 0; 196 double qa = 0; 197 while (true) { 198 final double u = rng.nextDouble(); 199 if (u <= p1) { 200 final double n = gaussian.sample(); 201 x = n * Math.sqrt(lambda + halfDelta) - 0.5d; 202 if (x > delta || x < -lambda) { 203 continue; 204 } 205 y = x < 0 ? Math.floor(x) : Math.ceil(x); 206 final double e = exponential.sample(); 207 v = -e - 0.5 * n * n + c1; 208 } else { 209 if (u > p1 + p2) { 210 y = lambda; 211 break; 212 } 213 x = delta + (twolpd / delta) * exponential.sample(); 214 y = Math.ceil(x); 215 v = -exponential.sample() - delta * (x + 1) / twolpd; 216 } 217 a = x < 0 ? 1 : 0; 218 t = y * (y + 1) / (2 * lambda); 219 if (v < -t && a == 0) { 220 y = lambda + y; 221 break; 222 } 223 qr = t * ((2 * y + 1) / (6 * lambda) - 1); 224 qa = qr - (t * t) / (3 * (lambda + a * (y + 1))); 225 if (v < qa) { 226 y = lambda + y; 227 break; 228 } 229 if (v > qr) { 230 continue; 231 } 232 if (v < y * logLambda - factorialLog((int) (y + lambda)) + logLambdaFactorial) { 233 y = lambda + y; 234 break; 235 } 236 } 237 238 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE); 239 } 240 241 /** 242 * Compute the natural logarithm of the factorial of {@code n}. 243 * 244 * @param n Argument. 245 * @return {@code log(n!)} 246 * @throws IllegalArgumentException if {@code n < 0}. 247 */ 248 private double factorialLog(int n) { 249 return factorialLog.value(n); 250 } 251 252 /** {@inheritDoc} */ 253 @Override 254 public String toString() { 255 return "Large Mean Poisson deviate [" + rng.toString() + "]"; 256 } 257 258 /** 259 * Gets the initialisation state of the sampler. 260 * 261 * <p>The state is computed using an integer {@code lambda} value of 262 * {@code lambda = (int)Math.floor(mean)}. 263 * 264 * <p>The state will be suitable for reconstructing a new sampler with a mean 265 * in the range {@code lambda <= mean < lambda+1} using 266 * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}. 267 * 268 * @return the state 269 */ 270 LargeMeanPoissonSamplerState getState() { 271 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial, 272 delta, halfDelta, twolpd, p1, p2, c1); 273 } 274 275 /** 276 * Encapsulate the state of the sampler. The state is valid for construction of 277 * a sampler in the range {@code lambda <= mean < lambda+1}. 278 * 279 * <p>This class is immutable. 280 * 281 * @see #getLambda() 282 */ 283 static final class LargeMeanPoissonSamplerState { 284 /** Algorithm constant {@code lambda}. */ 285 private final double lambda; 286 /** Algorithm constant {@code logLambda}. */ 287 private final double logLambda; 288 /** Algorithm constant {@code logLambdaFactorial}. */ 289 private final double logLambdaFactorial; 290 /** Algorithm constant {@code delta}. */ 291 private final double delta; 292 /** Algorithm constant {@code halfDelta}. */ 293 private final double halfDelta; 294 /** Algorithm constant {@code twolpd}. */ 295 private final double twolpd; 296 /** Algorithm constant {@code p1}. */ 297 private final double p1; 298 /** Algorithm constant {@code p2}. */ 299 private final double p2; 300 /** Algorithm constant {@code c1}. */ 301 private final double c1; 302 303 /** 304 * Creates the state. 305 * 306 * <p>The state is valid for construction of a sampler in the range 307 * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer. 308 * 309 * @param lambda the lambda 310 * @param logLambda the log lambda 311 * @param logLambdaFactorial the log lambda factorial 312 * @param delta the delta 313 * @param halfDelta the half delta 314 * @param twolpd the two lambda plus delta 315 * @param p1 the p1 constant 316 * @param p2 the p2 constant 317 * @param c1 the c1 constant 318 */ 319 private LargeMeanPoissonSamplerState(double lambda, double logLambda, 320 double logLambdaFactorial, double delta, double halfDelta, double twolpd, 321 double p1, double p2, double c1) { 322 this.lambda = lambda; 323 this.logLambda = logLambda; 324 this.logLambdaFactorial = logLambdaFactorial; 325 this.delta = delta; 326 this.halfDelta = halfDelta; 327 this.twolpd = twolpd; 328 this.p1 = p1; 329 this.p2 = p2; 330 this.c1 = c1; 331 } 332 333 /** 334 * Get the lambda value for the state. 335 * 336 * <p>Equal to {@code floor(mean)} for a Poisson sampler. 337 * @return the lambda value 338 */ 339 int getLambda() { 340 return (int) getLambdaRaw(); 341 } 342 343 /** 344 * @return algorithm constant {@code lambda} 345 */ 346 double getLambdaRaw() { 347 return lambda; 348 } 349 350 /** 351 * @return algorithm constant {@code logLambda} 352 */ 353 double getLogLambda() { 354 return logLambda; 355 } 356 357 /** 358 * @return algorithm constant {@code logLambdaFactorial} 359 */ 360 double getLogLambdaFactorial() { 361 return logLambdaFactorial; 362 } 363 364 /** 365 * @return algorithm constant {@code delta} 366 */ 367 double getDelta() { 368 return delta; 369 } 370 371 /** 372 * @return algorithm constant {@code halfDelta} 373 */ 374 double getHalfDelta() { 375 return halfDelta; 376 } 377 378 /** 379 * @return algorithm constant {@code twolpd} 380 */ 381 double getTwolpd() { 382 return twolpd; 383 } 384 385 /** 386 * @return algorithm constant {@code p1} 387 */ 388 double getP1() { 389 return p1; 390 } 391 392 /** 393 * @return algorithm constant {@code p2} 394 */ 395 double getP2() { 396 return p2; 397 } 398 399 /** 400 * @return algorithm constant {@code c1} 401 */ 402 double getC1() { 403 return c1; 404 } 405 } 406}