LargeMeanPoissonSampler.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling.distribution;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;

/**
 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
 *
 * <ul>
 *  <li>
 *   For large means, we use the rejection algorithm described in
 *   <blockquote>
 *    Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random Variables</i><br>
 *    <strong>Computing</strong> vol. 26 pp. 197-207.
 *   </blockquote>
 *  </li>
 * </ul>
 *
 * <p>This sampler is suitable for {@code mean >= 40}.</p>
 *
 * <p>Sampling uses:</p>
 *
 * <ul>
 *   <li>{@link UniformRandomProvider#nextLong()}
 *   <li>{@link UniformRandomProvider#nextDouble()}
 * </ul>
 *
 * @since 1.1
 */
public class LargeMeanPoissonSampler
    implements SharedStateDiscreteSampler {
    /** Upper bound to avoid truncation. */
    private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
    /** Class to compute {@code log(n!)}. This has no cached values. */
    private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
    /** Used when there is no requirement for a small mean Poisson sampler. */
    private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
        new SharedStateDiscreteSampler() {
            @Override
            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
                // No requirement for RNG
                return this;
            }

            @Override
            public int sample() {
                // No Poisson sample
                return 0;
            }
        };

    static {
        // Create without a cache.
        NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
    }

    /** Underlying source of randomness. */
    private final UniformRandomProvider rng;
    /** Exponential. */
    private final SharedStateContinuousSampler exponential;
    /** Gaussian. */
    private final SharedStateContinuousSampler gaussian;
    /** Local class to compute {@code log(n!)}. This may have cached values. */
    private final InternalUtils.FactorialLog factorialLog;

    // Working values

    /** Algorithm constant: {@code Math.floor(mean)}. */
    private final double lambda;
    /** Algorithm constant: {@code Math.log(lambda)}. */
    private final double logLambda;
    /** Algorithm constant: {@code factorialLog((int) lambda)}. */
    private final double logLambdaFactorial;
    /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
    private final double delta;
    /** Algorithm constant: {@code delta / 2}. */
    private final double halfDelta;
    /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
    private final double sqrtLambdaPlusHalfDelta;
    /** Algorithm constant: {@code 2 * lambda + delta}. */
    private final double twolpd;
    /**
     * Algorithm constant: {@code a1 / aSum}.
     * <ul>
     *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
     *  <li>{@code aSum = a1 + a2 + 1}</li>
     * </ul>
     */
    private final double p1;
    /**
     * Algorithm constant: {@code a2 / aSum}.
     * <ul>
     *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
     *  <li>{@code aSum = a1 + a2 + 1}</li>
     * </ul>
     */
    private final double p2;
    /** Algorithm constant: {@code 1 / (8 * lambda)}. */
    private final double c1;

    /** The internal Poisson sampler for the lambda fraction. */
    private final SharedStateDiscreteSampler smallMeanPoissonSampler;

    /**
     * Creates a sampler for the given mean.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param mean Mean.
     * @throws IllegalArgumentException if {@code mean < 1},
     * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}, or {@code mean} is NaN.
     */
    public LargeMeanPoissonSampler(UniformRandomProvider rng,
                                   double mean) {
        // Written as a negated comparison so that NaN (for which every
        // comparison is false) is rejected here rather than passing through.
        if (!(mean >= 1)) {
            throw new IllegalArgumentException("mean is not >= 1: " + mean);
        }
        // The algorithm is not valid if Math.floor(mean) is not an integer:
        // lambda and (lambda + deviate) are cast to int below.
        if (mean > MAX_MEAN) {
            throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
        }
        this.rng = rng;

        gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
        exponential = ZigguratSampler.Exponential.of(rng);
        // Plain constructor uses the uncached function.
        factorialLog = NO_CACHE_FACTORIAL_LOG;

        // Cache values used in the algorithm
        lambda = Math.floor(mean);
        logLambda = Math.log(lambda);
        logLambdaFactorial = getFactorialLog((int) lambda);
        delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
        halfDelta = delta / 2;
        sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
        twolpd = 2 * lambda + delta;
        c1 = 1 / (8 * lambda);
        final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
        final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
        final double aSum = a1 + a2 + 1;
        p1 = a1 / aSum;
        p2 = a2 / aSum;

        // The algorithm requires a Poisson sample from the remaining lambda fraction.
        smallMeanPoissonSampler = createSmallMeanPoissonSampler(rng, mean - lambda);
    }

    /**
     * Instantiates a sampler using a precomputed state.
     *
     * @param rng              Generator of uniformly distributed random numbers.
     * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
     * @param lambdaFractional The lambda fractional value
     *                         ({@code mean - (int)Math.floor(mean))}.
     * @throws IllegalArgumentException
     *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}
     *                         or {@code lambdaFractional} is NaN.
     */
    LargeMeanPoissonSampler(UniformRandomProvider rng,
                            LargeMeanPoissonSamplerState state,
                            double lambdaFractional) {
        // Negated range test so that NaN is also rejected.
        if (!(lambdaFractional >= 0 && lambdaFractional < 1)) {
            throw new IllegalArgumentException(
                    "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
        }
        this.rng = rng;

        gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
        exponential = ZigguratSampler.Exponential.of(rng);
        // Construction from a state also uses the uncached function.
        factorialLog = NO_CACHE_FACTORIAL_LOG;

        // Use the state to initialize the algorithm
        lambda = state.getLambdaRaw();
        logLambda = state.getLogLambda();
        logLambdaFactorial = state.getLogLambdaFactorial();
        delta = state.getDelta();
        halfDelta = state.getHalfDelta();
        sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
        twolpd = state.getTwolpd();
        p1 = state.getP1();
        p2 = state.getP2();
        c1 = state.getC1();

        // The algorithm requires a Poisson sample from the remaining lambda fraction.
        smallMeanPoissonSampler = createSmallMeanPoissonSampler(rng, lambdaFractional);
    }

    /**
     * @param rng Generator of uniformly distributed random numbers.
     * @param source Source to copy.
     */
    private LargeMeanPoissonSampler(UniformRandomProvider rng,
                                    LargeMeanPoissonSampler source) {
        this.rng = rng;

        gaussian = source.gaussian.withUniformRandomProvider(rng);
        exponential = source.exponential.withUniformRandomProvider(rng);
        // Reuse the cache
        factorialLog = source.factorialLog;

        lambda = source.lambda;
        logLambda = source.logLambda;
        logLambdaFactorial = source.logLambdaFactorial;
        delta = source.delta;
        halfDelta = source.halfDelta;
        sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
        twolpd = source.twolpd;
        p1 = source.p1;
        p2 = source.p2;
        c1 = source.c1;

        // Share the state of the small sampler
        smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
    }

    /**
     * Creates the sampler for the fractional part of the mean.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param lambdaFractional Fractional part of the mean ({@code mean - floor(mean)}).
     * @return a no-op sampler that always returns 0 when the fraction is zero,
     * otherwise a small mean Poisson sampler for the fraction
     */
    private static SharedStateDiscreteSampler createSmallMeanPoissonSampler(UniformRandomProvider rng,
                                                                           double lambdaFractional) {
        return (lambdaFractional < Double.MIN_VALUE) ?
            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
    }

    /** {@inheritDoc} */
    @Override
    public int sample() {
        // The result is the sum of a Poisson(mean - lambda) deviate (y2) and the
        // Poisson(lambda) deviate produced by the rejection loop below.
        // This will never be null. It may be a no-op delegate that returns zero.
        final int y2 = smallMeanPoissonSampler.sample();

        double x;
        double y;
        double v;
        int a;
        double t;
        double qr;
        double qa;
        while (true) {
            // Step 1:
            final double u = rng.nextDouble();
            if (u <= p1) {
                // Step 2: normal proposal (probability p1).
                final double n = gaussian.sample();
                x = n * sqrtLambdaPlusHalfDelta - 0.5d;
                if (x > delta || x < -lambda) {
                    continue;
                }
                y = x < 0 ? Math.floor(x) : Math.ceil(x);
                final double e = exponential.sample();
                v = -e - 0.5 * n * n + c1;
            } else {
                // Step 3:
                if (u > p1 + p2) {
                    // Remaining probability 1 / aSum: zero deviation from lambda.
                    y = lambda;
                    break;
                }
                // Exponential tail proposal (probability p2).
                x = delta + (twolpd / delta) * exponential.sample();
                y = Math.ceil(x);
                v = -exponential.sample() - delta * (x + 1) / twolpd;
            }
            // The Squeeze Principle
            // Step 4.1:
            a = x < 0 ? 1 : 0;
            t = y * (y + 1) / (2 * lambda);
            // Step 4.2
            if (v < -t && a == 0) {
                y = lambda + y;
                break;
            }
            // Step 4.3:
            qr = t * ((2 * y + 1) / (6 * lambda) - 1);
            qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
            // Step 4.4: quick accept
            if (v < qa) {
                y = lambda + y;
                break;
            }
            // Step 4.5: quick reject
            if (v > qr) {
                continue;
            }
            // Step 4.6: exact test using log factorials
            if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
                y = lambda + y;
                break;
            }
        }

        // Sum in long arithmetic and clip to the int range.
        return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
    }

    /**
     * Compute the natural logarithm of the factorial of {@code n}.
     *
     * @param n Argument.
     * @return {@code log(n!)}
     * @throws IllegalArgumentException if {@code n < 0}.
     */
    private double getFactorialLog(int n) {
        return factorialLog.value(n);
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return "Large Mean Poisson deviate [" + rng.toString() + "]";
    }

    /**
     * {@inheritDoc}
     *
     * @since 1.3
     */
    @Override
    public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
        return new LargeMeanPoissonSampler(rng, this);
    }

    /**
     * Creates a new Poisson distribution sampler.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param mean Mean.
     * @return the sampler
     * @throws IllegalArgumentException if {@code mean < 1}, {@code mean > 0.5 *}
     * {@link Integer#MAX_VALUE}, or {@code mean} is NaN.
     * @since 1.3
     */
    public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
                                                double mean) {
        return new LargeMeanPoissonSampler(rng, mean);
    }

    /**
     * Gets the initialisation state of the sampler.
     *
     * <p>The state is computed using an integer {@code lambda} value of
     * {@code lambda = (int)Math.floor(mean)}.
     *
     * <p>The state will be suitable for reconstructing a new sampler with a mean
     * in the range {@code lambda <= mean < lambda+1} using
     * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
     *
     * @return the state
     */
    LargeMeanPoissonSamplerState getState() {
        return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
                delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
    }

    /**
     * Encapsulate the state of the sampler. The state is valid for construction of
     * a sampler in the range {@code lambda <= mean < lambda+1}.
     *
     * <p>This class is immutable.
     *
     * @see #getLambda()
     */
    static final class LargeMeanPoissonSamplerState {
        /** Algorithm constant {@code lambda}. */
        private final double lambda;
        /** Algorithm constant {@code logLambda}. */
        private final double logLambda;
        /** Algorithm constant {@code logLambdaFactorial}. */
        private final double logLambdaFactorial;
        /** Algorithm constant {@code delta}. */
        private final double delta;
        /** Algorithm constant {@code halfDelta}. */
        private final double halfDelta;
        /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
        private final double sqrtLambdaPlusHalfDelta;
        /** Algorithm constant {@code twolpd}. */
        private final double twolpd;
        /** Algorithm constant {@code p1}. */
        private final double p1;
        /** Algorithm constant {@code p2}. */
        private final double p2;
        /** Algorithm constant {@code c1}. */
        private final double c1;

        /**
         * Creates the state.
         *
         * <p>The state is valid for construction of a sampler in the range
         * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
         *
         * @param lambda the lambda
         * @param logLambda the log lambda
         * @param logLambdaFactorial the log lambda factorial
         * @param delta the delta
         * @param halfDelta the half delta
         * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
         * @param twolpd the two lambda plus delta
         * @param p1 the p1 constant
         * @param p2 the p2 constant
         * @param c1 the c1 constant
         */
        LargeMeanPoissonSamplerState(double lambda, double logLambda,
                double logLambdaFactorial, double delta, double halfDelta,
                double sqrtLambdaPlusHalfDelta, double twolpd,
                double p1, double p2, double c1) {
            this.lambda = lambda;
            this.logLambda = logLambda;
            this.logLambdaFactorial = logLambdaFactorial;
            this.delta = delta;
            this.halfDelta = halfDelta;
            this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
            this.twolpd = twolpd;
            this.p1 = p1;
            this.p2 = p2;
            this.c1 = c1;
        }

        /**
         * Get the lambda value for the state.
         *
         * <p>Equal to {@code floor(mean)} for a Poisson sampler.
         * @return the lambda value
         */
        int getLambda() {
            return (int) getLambdaRaw();
        }

        /**
         * @return algorithm constant {@code lambda}
         */
        double getLambdaRaw() {
            return lambda;
        }

        /**
         * @return algorithm constant {@code logLambda}
         */
        double getLogLambda() {
            return logLambda;
        }

        /**
         * @return algorithm constant {@code logLambdaFactorial}
         */
        double getLogLambdaFactorial() {
            return logLambdaFactorial;
        }

        /**
         * @return algorithm constant {@code delta}
         */
        double getDelta() {
            return delta;
        }

        /**
         * @return algorithm constant {@code halfDelta}
         */
        double getHalfDelta() {
            return halfDelta;
        }

        /**
         * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
         */
        double getSqrtLambdaPlusHalfDelta() {
            return sqrtLambdaPlusHalfDelta;
        }

        /**
         * @return algorithm constant {@code twolpd}
         */
        double getTwolpd() {
            return twolpd;
        }

        /**
         * @return algorithm constant {@code p1}
         */
        double getP1() {
            return p1;
        }

        /**
         * @return algorithm constant {@code p2}
         */
        double getP2() {
            return p2;
        }

        /**
         * @return algorithm constant {@code c1}
         */
        double getC1() {
            return c1;
        }
    }
}