LargeMeanPoissonSampler.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.rng.sampling.distribution;
- import org.apache.commons.rng.UniformRandomProvider;
- import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
- /**
- * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
- *
- * <ul>
- * <li>
- * For large means, we use the rejection algorithm described in
- * <blockquote>
- * Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random Variables</i><br>
- * <strong>Computing</strong> vol. 26 pp. 197-207.
- * </blockquote>
- * </li>
- * </ul>
- *
- * <p>This sampler is suitable for {@code mean >= 40}.</p>
- *
- * <p>Sampling uses:</p>
- *
- * <ul>
- * <li>{@link UniformRandomProvider#nextLong()}
- * <li>{@link UniformRandomProvider#nextDouble()}
- * </ul>
- *
- * @since 1.1
- */
- public class LargeMeanPoissonSampler
- implements SharedStateDiscreteSampler {
- /** Upper bound to avoid truncation. */
- private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
- /** Class to compute {@code log(n!)}. This has no cached values. */
- private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
- /** Used when there is no requirement for a small mean Poisson sampler. */
- private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
- new SharedStateDiscreteSampler() {
- @Override
- public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
- // No requirement for RNG
- return this;
- }
- @Override
- public int sample() {
- // No Poisson sample
- return 0;
- }
- };
- static {
- // Create without a cache.
- NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
- }
- /** Underlying source of randomness. */
- private final UniformRandomProvider rng;
- /** Exponential. */
- private final SharedStateContinuousSampler exponential;
- /** Gaussian. */
- private final SharedStateContinuousSampler gaussian;
- /** Local class to compute {@code log(n!)}. This may have cached values. */
- private final InternalUtils.FactorialLog factorialLog;
- // Working values
- /** Algorithm constant: {@code Math.floor(mean)}. */
- private final double lambda;
- /** Algorithm constant: {@code Math.log(lambda)}. */
- private final double logLambda;
- /** Algorithm constant: {@code factorialLog((int) lambda)}. */
- private final double logLambdaFactorial;
- /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
- private final double delta;
- /** Algorithm constant: {@code delta / 2}. */
- private final double halfDelta;
- /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
- private final double sqrtLambdaPlusHalfDelta;
- /** Algorithm constant: {@code 2 * lambda + delta}. */
- private final double twolpd;
- /**
- * Algorithm constant: {@code a1 / aSum}.
- * <ul>
- * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
- * <li>{@code aSum = a1 + a2 + 1}</li>
- * </ul>
- */
- private final double p1;
- /**
- * Algorithm constant: {@code a2 / aSum}.
- * <ul>
- * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
- * <li>{@code aSum = a1 + a2 + 1}</li>
- * </ul>
- */
- private final double p2;
- /** Algorithm constant: {@code 1 / (8 * lambda)}. */
- private final double c1;
- /** The internal Poisson sampler for the lambda fraction. */
- private final SharedStateDiscreteSampler smallMeanPoissonSampler;
/**
 * Creates a sampler for the given mean.
 *
 * @param rng Generator of uniformly distributed random numbers.
 * @param mean Mean.
 * @throws IllegalArgumentException if {@code mean < 1} or
 * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
 */
public LargeMeanPoissonSampler(UniformRandomProvider rng,
                               double mean) {
    // The range check is evaluated as an argument of this(...), i.e. before the
    // java.lang.Object constructor exits, so a bad mean cannot leave behind a
    // partially initialized object.
    this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng);
}
/**
 * Creates a sampler from a state that was computed beforehand.
 *
 * @param rng Generator of uniformly distributed random numbers.
 * @param state The state for {@code lambda = (int)Math.floor(mean)}.
 * @param lambdaFractional The fractional part of the mean
 * ({@code mean - (int)Math.floor(mean)}).
 * @throws IllegalArgumentException
 * if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
 */
LargeMeanPoissonSampler(UniformRandomProvider rng,
                        LargeMeanPoissonSamplerState state,
                        double lambdaFractional) {
    // The range check is evaluated as an argument of this(...), i.e. before the
    // java.lang.Object constructor exits, so a bad fraction cannot leave behind a
    // partially initialized object.
    this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng);
}
/**
 * Builds the sampler and computes every algorithm constant from the mean.
 * The mean is not checked here.
 *
 * @param mean Mean.
 * @param rng Generator of uniformly distributed random numbers.
 */
private LargeMeanPoissonSampler(double mean,
                                UniformRandomProvider rng) {
    this.rng = rng;
    this.gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
    this.exponential = ZigguratSampler.Exponential.of(rng);
    // This constructor has no cache to share: use the uncached function.
    this.factorialLog = NO_CACHE_FACTORIAL_LOG;

    // Compute the constants once as locals, then store them in the fields.
    final double floorMean = Math.floor(mean);
    final double d = Math.sqrt(floorMean * Math.log(32 * floorMean / Math.PI + 1));
    final double halfD = d / 2;
    final double twoLambdaPlusD = 2 * floorMean + d;
    final double c = 1 / (8 * floorMean);
    final double a1 = Math.sqrt(Math.PI * twoLambdaPlusD) * Math.exp(c);
    final double a2 = (twoLambdaPlusD / d) * Math.exp(-d * (1 + d) / twoLambdaPlusD);
    final double aSum = a1 + a2 + 1;

    this.lambda = floorMean;
    this.logLambda = Math.log(floorMean);
    this.logLambdaFactorial = getFactorialLog((int) floorMean);
    this.delta = d;
    this.halfDelta = halfD;
    this.sqrtLambdaPlusHalfDelta = Math.sqrt(floorMean + halfD);
    this.twolpd = twoLambdaPlusD;
    this.c1 = c;
    this.p1 = a1 / aSum;
    this.p2 = a2 / aSum;

    // A Poisson sample is also needed for the part of the mean above lambda.
    final double lambdaFractional = mean - floorMean;
    if (lambdaFractional < Double.MIN_VALUE) {
        // Nothing remains, so the delegate is never really used.
        this.smallMeanPoissonSampler = NO_SMALL_MEAN_POISSON_SAMPLER;
    } else {
        this.smallMeanPoissonSampler = KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
    }
}
/**
 * Builds the sampler from a state computed beforehand, copying each algorithm
 * constant out of the state. The fractional value is not checked here.
 *
 * @param state The state for {@code lambda = (int)Math.floor(mean)}.
 * @param lambdaFractional The fractional part of the mean
 * ({@code mean - (int)Math.floor(mean)}).
 * @param rng Generator of uniformly distributed random numbers.
 */
private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state,
                                double lambdaFractional,
                                UniformRandomProvider rng) {
    this.rng = rng;
    this.gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
    this.exponential = ZigguratSampler.Exponential.of(rng);
    // This constructor has no cache to share: use the uncached function.
    this.factorialLog = NO_CACHE_FACTORIAL_LOG;

    // Algorithm constants come straight from the state.
    this.lambda = state.getLambdaRaw();
    this.logLambda = state.getLogLambda();
    this.logLambdaFactorial = state.getLogLambdaFactorial();
    this.delta = state.getDelta();
    this.halfDelta = state.getHalfDelta();
    this.sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
    this.twolpd = state.getTwolpd();
    this.p1 = state.getP1();
    this.p2 = state.getP2();
    this.c1 = state.getC1();

    // A Poisson sample is also needed for the part of the mean above lambda.
    if (lambdaFractional < Double.MIN_VALUE) {
        // Nothing remains, so the delegate is never really used.
        this.smallMeanPoissonSampler = NO_SMALL_MEAN_POISSON_SAMPLER;
    } else {
        this.smallMeanPoissonSampler = KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
    }
}
/**
 * Copies a sampler, replacing its generator.
 * The delegate samplers are rebound to the new generator; the constants and
 * the {@code log(n!)} function are taken over from the source as they are.
 *
 * @param rng Generator of uniformly distributed random numbers.
 * @param source Source to copy.
 */
private LargeMeanPoissonSampler(UniformRandomProvider rng,
                                LargeMeanPoissonSampler source) {
    this.rng = rng;
    this.gaussian = source.gaussian.withUniformRandomProvider(rng);
    this.exponential = source.exponential.withUniformRandomProvider(rng);

    // The source's factorial function (and any cache it has) is reused.
    this.factorialLog = source.factorialLog;

    // Algorithm constants.
    this.lambda = source.lambda;
    this.logLambda = source.logLambda;
    this.logLambdaFactorial = source.logLambdaFactorial;
    this.delta = source.delta;
    this.halfDelta = source.halfDelta;
    this.sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
    this.twolpd = source.twolpd;
    this.p1 = source.p1;
    this.p2 = source.p2;
    this.c1 = source.c1;

    // The small mean sampler shares its state with the source's sampler.
    this.smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
}
- /** {@inheritDoc} */
- @Override
- public int sample() {
- // This will never be null. It may be a no-op delegate that returns zero.
- final int y2 = smallMeanPoissonSampler.sample();
- double x;
- double y;
- double v;
- int a;
- double t;
- double qr;
- double qa;
- while (true) {
- // Step 1:
- final double u = rng.nextDouble();
- if (u <= p1) {
- // Step 2:
- final double n = gaussian.sample();
- x = n * sqrtLambdaPlusHalfDelta - 0.5d;
- if (x > delta || x < -lambda) {
- continue;
- }
- y = x < 0 ? Math.floor(x) : Math.ceil(x);
- final double e = exponential.sample();
- v = -e - 0.5 * n * n + c1;
- } else {
- // Step 3:
- if (u > p1 + p2) {
- y = lambda;
- break;
- }
- x = delta + (twolpd / delta) * exponential.sample();
- y = Math.ceil(x);
- v = -exponential.sample() - delta * (x + 1) / twolpd;
- }
- // The Squeeze Principle
- // Step 4.1:
- a = x < 0 ? 1 : 0;
- t = y * (y + 1) / (2 * lambda);
- // Step 4.2
- if (v < -t && a == 0) {
- y = lambda + y;
- break;
- }
- // Step 4.3:
- qr = t * ((2 * y + 1) / (6 * lambda) - 1);
- qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
- // Step 4.4:
- if (v < qa) {
- y = lambda + y;
- break;
- }
- // Step 4.5:
- if (v > qr) {
- continue;
- }
- // Step 4.6:
- if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
- y = lambda + y;
- break;
- }
- }
- return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
- }
- /**
- * Compute the natural logarithm of the factorial of {@code n}.
- *
- * @param n Argument.
- * @return {@code log(n!)}
- * @throws IllegalArgumentException if {@code n < 0}.
- */
- private double getFactorialLog(int n) {
- return factorialLog.value(n);
- }
- /** {@inheritDoc} */
- @Override
- public String toString() {
- return "Large Mean Poisson deviate [" + rng.toString() + "]";
- }
- /**
- * {@inheritDoc}
- *
- * @since 1.3
- */
- @Override
- public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
- return new LargeMeanPoissonSampler(rng, this);
- }
- /**
- * Creates a new Poisson distribution sampler.
- *
- * @param rng Generator of uniformly distributed random numbers.
- * @param mean Mean.
- * @return the sampler
- * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
- * {@link Integer#MAX_VALUE}.
- * @since 1.3
- */
- public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
- double mean) {
- return new LargeMeanPoissonSampler(rng, mean);
- }
- /**
- * Gets the initialisation state of the sampler.
- *
- * <p>The state is computed using an integer {@code lambda} value of
- * {@code lambda = (int)Math.floor(mean)}.
- *
- * <p>The state will be suitable for reconstructing a new sampler with a mean
- * in the range {@code lambda <= mean < lambda+1} using
- * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
- *
- * @return the state
- */
- LargeMeanPoissonSamplerState getState() {
- return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
- delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
- }
- /**
- * Encapsulate the state of the sampler. The state is valid for construction of
- * a sampler in the range {@code lambda <= mean < lambda+1}.
- *
- * <p>This class is immutable.
- *
- * @see #getLambda()
- */
- static final class LargeMeanPoissonSamplerState {
- /** Algorithm constant {@code lambda}. */
- private final double lambda;
- /** Algorithm constant {@code logLambda}. */
- private final double logLambda;
- /** Algorithm constant {@code logLambdaFactorial}. */
- private final double logLambdaFactorial;
- /** Algorithm constant {@code delta}. */
- private final double delta;
- /** Algorithm constant {@code halfDelta}. */
- private final double halfDelta;
- /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
- private final double sqrtLambdaPlusHalfDelta;
- /** Algorithm constant {@code twolpd}. */
- private final double twolpd;
- /** Algorithm constant {@code p1}. */
- private final double p1;
- /** Algorithm constant {@code p2}. */
- private final double p2;
- /** Algorithm constant {@code c1}. */
- private final double c1;
- /**
- * Creates the state.
- *
- * <p>The state is valid for construction of a sampler in the range
- * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
- *
- * @param lambda the lambda
- * @param logLambda the log lambda
- * @param logLambdaFactorial the log lambda factorial
- * @param delta the delta
- * @param halfDelta the half delta
- * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
- * @param twolpd the two lambda plus delta
- * @param p1 the p1 constant
- * @param p2 the p2 constant
- * @param c1 the c1 constant
- */
- LargeMeanPoissonSamplerState(double lambda, double logLambda,
- double logLambdaFactorial, double delta, double halfDelta,
- double sqrtLambdaPlusHalfDelta, double twolpd,
- double p1, double p2, double c1) {
- this.lambda = lambda;
- this.logLambda = logLambda;
- this.logLambdaFactorial = logLambdaFactorial;
- this.delta = delta;
- this.halfDelta = halfDelta;
- this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
- this.twolpd = twolpd;
- this.p1 = p1;
- this.p2 = p2;
- this.c1 = c1;
- }
- /**
- * Get the lambda value for the state.
- *
- * <p>Equal to {@code floor(mean)} for a Poisson sampler.
- * @return the lambda value
- */
- int getLambda() {
- return (int) getLambdaRaw();
- }
- /**
- * @return algorithm constant {@code lambda}
- */
- double getLambdaRaw() {
- return lambda;
- }
- /**
- * @return algorithm constant {@code logLambda}
- */
- double getLogLambda() {
- return logLambda;
- }
- /**
- * @return algorithm constant {@code logLambdaFactorial}
- */
- double getLogLambdaFactorial() {
- return logLambdaFactorial;
- }
- /**
- * @return algorithm constant {@code delta}
- */
- double getDelta() {
- return delta;
- }
- /**
- * @return algorithm constant {@code halfDelta}
- */
- double getHalfDelta() {
- return halfDelta;
- }
- /**
- * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
- */
- double getSqrtLambdaPlusHalfDelta() {
- return sqrtLambdaPlusHalfDelta;
- }
- /**
- * @return algorithm constant {@code twolpd}
- */
- double getTwolpd() {
- return twolpd;
- }
- /**
- * @return algorithm constant {@code p1}
- */
- double getP1() {
- return p1;
- }
- /**
- * @return algorithm constant {@code p2}
- */
- double getP2() {
- return p2;
- }
- /**
- * @return algorithm constant {@code c1}
- */
- double getC1() {
- return c1;
- }
- }
- }