001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
021
022/**
023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
024 *
025 * <ul>
026 *  <li>
027 *   For large means, we use the rejection algorithm described in
028 *   <blockquote>
 *    Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random Variables</i><br>
030 *    <strong>Computing</strong> vol. 26 pp. 197-207.
031 *   </blockquote>
032 *  </li>
033 * </ul>
034 *
035 * <p>This sampler is suitable for {@code mean >= 40}.</p>
036 *
037 * <p>Sampling uses:</p>
038 *
039 * <ul>
040 *   <li>{@link UniformRandomProvider#nextLong()}
041 *   <li>{@link UniformRandomProvider#nextDouble()}
042 * </ul>
043 *
044 * @since 1.1
045 */
046public class LargeMeanPoissonSampler
047    implements SharedStateDiscreteSampler {
048    /** Upper bound to avoid truncation. */
049    private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
050    /** Class to compute {@code log(n!)}. This has no cached values. */
051    private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
052    /** Used when there is no requirement for a small mean Poisson sampler. */
053    private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
054        new SharedStateDiscreteSampler() {
055            @Override
056            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
057                // No requirement for RNG
058                return this;
059            }
060
061            @Override
062            public int sample() {
063                // No Poisson sample
064                return 0;
065            }
066        };
067
068    static {
069        // Create without a cache.
070        NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
071    }
072
073    /** Underlying source of randomness. */
074    private final UniformRandomProvider rng;
075    /** Exponential. */
076    private final SharedStateContinuousSampler exponential;
077    /** Gaussian. */
078    private final SharedStateContinuousSampler gaussian;
079    /** Local class to compute {@code log(n!)}. This may have cached values. */
080    private final InternalUtils.FactorialLog factorialLog;
081
082    // Working values
083
084    /** Algorithm constant: {@code Math.floor(mean)}. */
085    private final double lambda;
086    /** Algorithm constant: {@code Math.log(lambda)}. */
087    private final double logLambda;
088    /** Algorithm constant: {@code factorialLog((int) lambda)}. */
089    private final double logLambdaFactorial;
090    /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
091    private final double delta;
092    /** Algorithm constant: {@code delta / 2}. */
093    private final double halfDelta;
094    /** Algorithm constant: {@code Math.sqrt(lambda + halfDelta)}. */
095    private final double sqrtLambdaPlusHalfDelta;
096    /** Algorithm constant: {@code 2 * lambda + delta}. */
097    private final double twolpd;
098    /**
099     * Algorithm constant: {@code a1 / aSum}.
100     * <ul>
101     *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
102     *  <li>{@code aSum = a1 + a2 + 1}</li>
103     * </ul>
104     */
105    private final double p1;
106    /**
107     * Algorithm constant: {@code a2 / aSum}.
108     * <ul>
109     *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
110     *  <li>{@code aSum = a1 + a2 + 1}</li>
111     * </ul>
112     */
113    private final double p2;
114    /** Algorithm constant: {@code 1 / (8 * lambda)}. */
115    private final double c1;
116
117    /** The internal Poisson sampler for the lambda fraction. */
118    private final SharedStateDiscreteSampler smallMeanPoissonSampler;
119
120    /**
121     * @param rng Generator of uniformly distributed random numbers.
122     * @param mean Mean.
123     * @throws IllegalArgumentException if {@code mean < 1} or
124     * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
125     */
126    public LargeMeanPoissonSampler(UniformRandomProvider rng,
127                                   double mean) {
128        if (mean < 1) {
129            throw new IllegalArgumentException("mean is not >= 1: " + mean);
130        }
131        // The algorithm is not valid if Math.floor(mean) is not an integer.
132        if (mean > MAX_MEAN) {
133            throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
134        }
135        this.rng = rng;
136
137        gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
138        exponential = ZigguratSampler.Exponential.of(rng);
139        // Plain constructor uses the uncached function.
140        factorialLog = NO_CACHE_FACTORIAL_LOG;
141
142        // Cache values used in the algorithm
143        lambda = Math.floor(mean);
144        logLambda = Math.log(lambda);
145        logLambdaFactorial = getFactorialLog((int) lambda);
146        delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
147        halfDelta = delta / 2;
148        sqrtLambdaPlusHalfDelta = Math.sqrt(lambda + halfDelta);
149        twolpd = 2 * lambda + delta;
150        c1 = 1 / (8 * lambda);
151        final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
152        final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
153        final double aSum = a1 + a2 + 1;
154        p1 = a1 / aSum;
155        p2 = a2 / aSum;
156
157        // The algorithm requires a Poisson sample from the remaining lambda fraction.
158        final double lambdaFractional = mean - lambda;
159        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
160            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
161            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
162    }
163
164    /**
165     * Instantiates a sampler using a precomputed state.
166     *
167     * @param rng              Generator of uniformly distributed random numbers.
168     * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
169     * @param lambdaFractional The lambda fractional value
170     *                         ({@code mean - (int)Math.floor(mean))}.
171     * @throws IllegalArgumentException
172     *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
173     */
174    LargeMeanPoissonSampler(UniformRandomProvider rng,
175                            LargeMeanPoissonSamplerState state,
176                            double lambdaFractional) {
177        if (lambdaFractional < 0 || lambdaFractional >= 1) {
178            throw new IllegalArgumentException(
179                    "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
180        }
181        this.rng = rng;
182
183        gaussian = ZigguratSampler.NormalizedGaussian.of(rng);
184        exponential = ZigguratSampler.Exponential.of(rng);
185        // Plain constructor uses the uncached function.
186        factorialLog = NO_CACHE_FACTORIAL_LOG;
187
188        // Use the state to initialize the algorithm
189        lambda = state.getLambdaRaw();
190        logLambda = state.getLogLambda();
191        logLambdaFactorial = state.getLogLambdaFactorial();
192        delta = state.getDelta();
193        halfDelta = state.getHalfDelta();
194        sqrtLambdaPlusHalfDelta = state.getSqrtLambdaPlusHalfDelta();
195        twolpd = state.getTwolpd();
196        p1 = state.getP1();
197        p2 = state.getP2();
198        c1 = state.getC1();
199
200        // The algorithm requires a Poisson sample from the remaining lambda fraction.
201        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
202            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
203            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
204    }
205
206    /**
207     * @param rng Generator of uniformly distributed random numbers.
208     * @param source Source to copy.
209     */
210    private LargeMeanPoissonSampler(UniformRandomProvider rng,
211                                    LargeMeanPoissonSampler source) {
212        this.rng = rng;
213
214        gaussian = source.gaussian.withUniformRandomProvider(rng);
215        exponential = source.exponential.withUniformRandomProvider(rng);
216        // Reuse the cache
217        factorialLog = source.factorialLog;
218
219        lambda = source.lambda;
220        logLambda = source.logLambda;
221        logLambdaFactorial = source.logLambdaFactorial;
222        delta = source.delta;
223        halfDelta = source.halfDelta;
224        sqrtLambdaPlusHalfDelta = source.sqrtLambdaPlusHalfDelta;
225        twolpd = source.twolpd;
226        p1 = source.p1;
227        p2 = source.p2;
228        c1 = source.c1;
229
230        // Share the state of the small sampler
231        smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
232    }
233
234    /** {@inheritDoc} */
235    @Override
236    public int sample() {
237        // This will never be null. It may be a no-op delegate that returns zero.
238        final int y2 = smallMeanPoissonSampler.sample();
239
240        double x;
241        double y;
242        double v;
243        int a;
244        double t;
245        double qr;
246        double qa;
247        while (true) {
248            // Step 1:
249            final double u = rng.nextDouble();
250            if (u <= p1) {
251                // Step 2:
252                final double n = gaussian.sample();
253                x = n * sqrtLambdaPlusHalfDelta - 0.5d;
254                if (x > delta || x < -lambda) {
255                    continue;
256                }
257                y = x < 0 ? Math.floor(x) : Math.ceil(x);
258                final double e = exponential.sample();
259                v = -e - 0.5 * n * n + c1;
260            } else {
261                // Step 3:
262                if (u > p1 + p2) {
263                    y = lambda;
264                    break;
265                }
266                x = delta + (twolpd / delta) * exponential.sample();
267                y = Math.ceil(x);
268                v = -exponential.sample() - delta * (x + 1) / twolpd;
269            }
270            // The Squeeze Principle
271            // Step 4.1:
272            a = x < 0 ? 1 : 0;
273            t = y * (y + 1) / (2 * lambda);
274            // Step 4.2
275            if (v < -t && a == 0) {
276                y = lambda + y;
277                break;
278            }
279            // Step 4.3:
280            qr = t * ((2 * y + 1) / (6 * lambda) - 1);
281            qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
282            // Step 4.4:
283            if (v < qa) {
284                y = lambda + y;
285                break;
286            }
287            // Step 4.5:
288            if (v > qr) {
289                continue;
290            }
291            // Step 4.6:
292            if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
293                y = lambda + y;
294                break;
295            }
296        }
297
298        return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
299    }
300
301    /**
302     * Compute the natural logarithm of the factorial of {@code n}.
303     *
304     * @param n Argument.
305     * @return {@code log(n!)}
306     * @throws IllegalArgumentException if {@code n < 0}.
307     */
308    private double getFactorialLog(int n) {
309        return factorialLog.value(n);
310    }
311
312    /** {@inheritDoc} */
313    @Override
314    public String toString() {
315        return "Large Mean Poisson deviate [" + rng.toString() + "]";
316    }
317
318    /**
319     * {@inheritDoc}
320     *
321     * @since 1.3
322     */
323    @Override
324    public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
325        return new LargeMeanPoissonSampler(rng, this);
326    }
327
328    /**
329     * Creates a new Poisson distribution sampler.
330     *
331     * @param rng Generator of uniformly distributed random numbers.
332     * @param mean Mean.
333     * @return the sampler
334     * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
335     * {@link Integer#MAX_VALUE}.
336     * @since 1.3
337     */
338    public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
339                                                double mean) {
340        return new LargeMeanPoissonSampler(rng, mean);
341    }
342
343    /**
344     * Gets the initialisation state of the sampler.
345     *
346     * <p>The state is computed using an integer {@code lambda} value of
347     * {@code lambda = (int)Math.floor(mean)}.
348     *
349     * <p>The state will be suitable for reconstructing a new sampler with a mean
350     * in the range {@code lambda <= mean < lambda+1} using
351     * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
352     *
353     * @return the state
354     */
355    LargeMeanPoissonSamplerState getState() {
356        return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
357                delta, halfDelta, sqrtLambdaPlusHalfDelta, twolpd, p1, p2, c1);
358    }
359
360    /**
361     * Encapsulate the state of the sampler. The state is valid for construction of
362     * a sampler in the range {@code lambda <= mean < lambda+1}.
363     *
364     * <p>This class is immutable.
365     *
366     * @see #getLambda()
367     */
368    static final class LargeMeanPoissonSamplerState {
369        /** Algorithm constant {@code lambda}. */
370        private final double lambda;
371        /** Algorithm constant {@code logLambda}. */
372        private final double logLambda;
373        /** Algorithm constant {@code logLambdaFactorial}. */
374        private final double logLambdaFactorial;
375        /** Algorithm constant {@code delta}. */
376        private final double delta;
377        /** Algorithm constant {@code halfDelta}. */
378        private final double halfDelta;
379        /** Algorithm constant {@code sqrtLambdaPlusHalfDelta}. */
380        private final double sqrtLambdaPlusHalfDelta;
381        /** Algorithm constant {@code twolpd}. */
382        private final double twolpd;
383        /** Algorithm constant {@code p1}. */
384        private final double p1;
385        /** Algorithm constant {@code p2}. */
386        private final double p2;
387        /** Algorithm constant {@code c1}. */
388        private final double c1;
389
390        /**
391         * Creates the state.
392         *
393         * <p>The state is valid for construction of a sampler in the range
394         * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
395         *
396         * @param lambda the lambda
397         * @param logLambda the log lambda
398         * @param logLambdaFactorial the log lambda factorial
399         * @param delta the delta
400         * @param halfDelta the half delta
401         * @param sqrtLambdaPlusHalfDelta the sqrt(lambda+half delta)
402         * @param twolpd the two lambda plus delta
403         * @param p1 the p1 constant
404         * @param p2 the p2 constant
405         * @param c1 the c1 constant
406         */
407        LargeMeanPoissonSamplerState(double lambda, double logLambda,
408                double logLambdaFactorial, double delta, double halfDelta,
409                double sqrtLambdaPlusHalfDelta, double twolpd,
410                double p1, double p2, double c1) {
411            this.lambda = lambda;
412            this.logLambda = logLambda;
413            this.logLambdaFactorial = logLambdaFactorial;
414            this.delta = delta;
415            this.halfDelta = halfDelta;
416            this.sqrtLambdaPlusHalfDelta = sqrtLambdaPlusHalfDelta;
417            this.twolpd = twolpd;
418            this.p1 = p1;
419            this.p2 = p2;
420            this.c1 = c1;
421        }
422
423        /**
424         * Get the lambda value for the state.
425         *
426         * <p>Equal to {@code floor(mean)} for a Poisson sampler.
427         * @return the lambda value
428         */
429        int getLambda() {
430            return (int) getLambdaRaw();
431        }
432
433        /**
434         * @return algorithm constant {@code lambda}
435         */
436        double getLambdaRaw() {
437            return lambda;
438        }
439
440        /**
441         * @return algorithm constant {@code logLambda}
442         */
443        double getLogLambda() {
444            return logLambda;
445        }
446
447        /**
448         * @return algorithm constant {@code logLambdaFactorial}
449         */
450        double getLogLambdaFactorial() {
451            return logLambdaFactorial;
452        }
453
454        /**
455         * @return algorithm constant {@code delta}
456         */
457        double getDelta() {
458            return delta;
459        }
460
461        /**
462         * @return algorithm constant {@code halfDelta}
463         */
464        double getHalfDelta() {
465            return halfDelta;
466        }
467
468        /**
469         * @return algorithm constant {@code sqrtLambdaPlusHalfDelta}
470         */
471        double getSqrtLambdaPlusHalfDelta() {
472            return sqrtLambdaPlusHalfDelta;
473        }
474
475        /**
476         * @return algorithm constant {@code twolpd}
477         */
478        double getTwolpd() {
479            return twolpd;
480        }
481
482        /**
483         * @return algorithm constant {@code p1}
484         */
485        double getP1() {
486            return p1;
487        }
488
489        /**
490         * @return algorithm constant {@code p2}
491         */
492        double getP2() {
493            return p2;
494        }
495
496        /**
497         * @return algorithm constant {@code c1}
498         */
499        double getC1() {
500            return c1;
501        }
502    }
503}