001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
021
022/**
023 * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
024 *
025 * <ul>
026 *  <li>
027 *   For large means, we use the rejection algorithm described in
028 *   <blockquote>
029 *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br>
030 *    <strong>Computing</strong> vol. 26 pp. 197-207.
031 *   </blockquote>
032 *  </li>
033 * </ul>
034 *
035 * <p>This sampler is suitable for {@code mean >= 40}.</p>
036 *
037 * <p>Sampling uses:</p>
038 *
039 * <ul>
040 *   <li>{@link UniformRandomProvider#nextLong()}
041 *   <li>{@link UniformRandomProvider#nextDouble()}
042 * </ul>
043 *
044 * @since 1.1
045 */
046public class LargeMeanPoissonSampler
047    implements SharedStateDiscreteSampler {
048    /** Upper bound to avoid truncation. */
049    private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
050    /** Class to compute {@code log(n!)}. This has no cached values. */
051    private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
052    /** Used when there is no requirement for a small mean Poisson sampler. */
053    private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
054        new SharedStateDiscreteSampler() {
055            @Override
056            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
057                // No requirement for RNG
058                return this;
059            }
060
061            @Override
062            public int sample() {
063                // No Poisson sample
064                return 0;
065            }
066        };
067
068    static {
069        // Create without a cache.
070        NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
071    }
072
073    /** Underlying source of randomness. */
074    private final UniformRandomProvider rng;
075    /** Exponential. */
076    private final SharedStateContinuousSampler exponential;
077    /** Gaussian. */
078    private final SharedStateContinuousSampler gaussian;
079    /** Local class to compute {@code log(n!)}. This may have cached values. */
080    private final InternalUtils.FactorialLog factorialLog;
081
082    // Working values
083
084    /** Algorithm constant: {@code Math.floor(mean)}. */
085    private final double lambda;
086    /** Algorithm constant: {@code Math.log(lambda)}. */
087    private final double logLambda;
088    /** Algorithm constant: {@code factorialLog((int) lambda)}. */
089    private final double logLambdaFactorial;
090    /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
091    private final double delta;
092    /** Algorithm constant: {@code delta / 2}. */
093    private final double halfDelta;
094    /** Algorithm constant: {@code 2 * lambda + delta}. */
095    private final double twolpd;
096    /**
097     * Algorithm constant: {@code a1 / aSum}.
098     * <ul>
099     *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
100     *  <li>{@code aSum = a1 + a2 + 1}</li>
101     * </ul>
102     */
103    private final double p1;
104    /**
105     * Algorithm constant: {@code a2 / aSum}.
106     * <ul>
107     *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
108     *  <li>{@code aSum = a1 + a2 + 1}</li>
109     * </ul>
110     */
111    private final double p2;
112    /** Algorithm constant: {@code 1 / (8 * lambda)}. */
113    private final double c1;
114
115    /** The internal Poisson sampler for the lambda fraction. */
116    private final SharedStateDiscreteSampler smallMeanPoissonSampler;
117
118    /**
119     * @param rng Generator of uniformly distributed random numbers.
120     * @param mean Mean.
121     * @throws IllegalArgumentException if {@code mean < 1} or
122     * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
123     */
124    public LargeMeanPoissonSampler(UniformRandomProvider rng,
125                                   double mean) {
126        if (mean < 1) {
127            throw new IllegalArgumentException("mean is not >= 1: " + mean);
128        }
129        // The algorithm is not valid if Math.floor(mean) is not an integer.
130        if (mean > MAX_MEAN) {
131            throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
132        }
133        this.rng = rng;
134
135        gaussian = new ZigguratNormalizedGaussianSampler(rng);
136        exponential = AhrensDieterExponentialSampler.of(rng, 1);
137        // Plain constructor uses the uncached function.
138        factorialLog = NO_CACHE_FACTORIAL_LOG;
139
140        // Cache values used in the algorithm
141        lambda = Math.floor(mean);
142        logLambda = Math.log(lambda);
143        logLambdaFactorial = getFactorialLog((int) lambda);
144        delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
145        halfDelta = delta / 2;
146        twolpd = 2 * lambda + delta;
147        c1 = 1 / (8 * lambda);
148        final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
149        final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
150        final double aSum = a1 + a2 + 1;
151        p1 = a1 / aSum;
152        p2 = a2 / aSum;
153
154        // The algorithm requires a Poisson sample from the remaining lambda fraction.
155        final double lambdaFractional = mean - lambda;
156        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
157            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
158            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
159    }
160
161    /**
162     * Instantiates a sampler using a precomputed state.
163     *
164     * @param rng              Generator of uniformly distributed random numbers.
165     * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
166     * @param lambdaFractional The lambda fractional value
167     *                         ({@code mean - (int)Math.floor(mean))}.
168     * @throws IllegalArgumentException
169     *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
170     */
171    LargeMeanPoissonSampler(UniformRandomProvider rng,
172                            LargeMeanPoissonSamplerState state,
173                            double lambdaFractional) {
174        if (lambdaFractional < 0 || lambdaFractional >= 1) {
175            throw new IllegalArgumentException(
176                    "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
177        }
178        this.rng = rng;
179
180        gaussian = new ZigguratNormalizedGaussianSampler(rng);
181        exponential = AhrensDieterExponentialSampler.of(rng, 1);
182        // Plain constructor uses the uncached function.
183        factorialLog = NO_CACHE_FACTORIAL_LOG;
184
185        // Use the state to initialise the algorithm
186        lambda = state.getLambdaRaw();
187        logLambda = state.getLogLambda();
188        logLambdaFactorial = state.getLogLambdaFactorial();
189        delta = state.getDelta();
190        halfDelta = state.getHalfDelta();
191        twolpd = state.getTwolpd();
192        p1 = state.getP1();
193        p2 = state.getP2();
194        c1 = state.getC1();
195
196        // The algorithm requires a Poisson sample from the remaining lambda fraction.
197        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
198            NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
199            KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
200    }
201
202    /**
203     * @param rng Generator of uniformly distributed random numbers.
204     * @param source Source to copy.
205     */
206    private LargeMeanPoissonSampler(UniformRandomProvider rng,
207                                    LargeMeanPoissonSampler source) {
208        this.rng = rng;
209
210        gaussian = source.gaussian.withUniformRandomProvider(rng);
211        exponential = source.exponential.withUniformRandomProvider(rng);
212        // Reuse the cache
213        factorialLog = source.factorialLog;
214
215        lambda = source.lambda;
216        logLambda = source.logLambda;
217        logLambdaFactorial = source.logLambdaFactorial;
218        delta = source.delta;
219        halfDelta = source.halfDelta;
220        twolpd = source.twolpd;
221        p1 = source.p1;
222        p2 = source.p2;
223        c1 = source.c1;
224
225        // Share the state of the small sampler
226        smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
227    }
228
229    /** {@inheritDoc} */
230    @Override
231    public int sample() {
232        // This will never be null. It may be a no-op delegate that returns zero.
233        final int y2 = smallMeanPoissonSampler.sample();
234
235        double x;
236        double y;
237        double v;
238        int a;
239        double t;
240        double qr;
241        double qa;
242        while (true) {
243            // Step 1:
244            final double u = rng.nextDouble();
245            if (u <= p1) {
246                // Step 2:
247                final double n = gaussian.sample();
248                x = n * Math.sqrt(lambda + halfDelta) - 0.5d;
249                if (x > delta || x < -lambda) {
250                    continue;
251                }
252                y = x < 0 ? Math.floor(x) : Math.ceil(x);
253                final double e = exponential.sample();
254                v = -e - 0.5 * n * n + c1;
255            } else {
256                // Step 3:
257                if (u > p1 + p2) {
258                    y = lambda;
259                    break;
260                }
261                x = delta + (twolpd / delta) * exponential.sample();
262                y = Math.ceil(x);
263                v = -exponential.sample() - delta * (x + 1) / twolpd;
264            }
265            // The Squeeze Principle
266            // Step 4.1:
267            a = x < 0 ? 1 : 0;
268            t = y * (y + 1) / (2 * lambda);
269            // Step 4.2
270            if (v < -t && a == 0) {
271                y = lambda + y;
272                break;
273            }
274            // Step 4.3:
275            qr = t * ((2 * y + 1) / (6 * lambda) - 1);
276            qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
277            // Step 4.4:
278            if (v < qa) {
279                y = lambda + y;
280                break;
281            }
282            // Step 4.5:
283            if (v > qr) {
284                continue;
285            }
286            // Step 4.6:
287            if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
288                y = lambda + y;
289                break;
290            }
291        }
292
293        return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
294    }
295
296    /**
297     * Compute the natural logarithm of the factorial of {@code n}.
298     *
299     * @param n Argument.
300     * @return {@code log(n!)}
301     * @throws IllegalArgumentException if {@code n < 0}.
302     */
303    private double getFactorialLog(int n) {
304        return factorialLog.value(n);
305    }
306
307    /** {@inheritDoc} */
308    @Override
309    public String toString() {
310        return "Large Mean Poisson deviate [" + rng.toString() + "]";
311    }
312
313    /**
314     * {@inheritDoc}
315     *
316     * @since 1.3
317     */
318    @Override
319    public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
320        return new LargeMeanPoissonSampler(rng, this);
321    }
322
323    /**
324     * Creates a new Poisson distribution sampler.
325     *
326     * @param rng Generator of uniformly distributed random numbers.
327     * @param mean Mean.
328     * @return the sampler
329     * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
330     * {@link Integer#MAX_VALUE}.
331     * @since 1.3
332     */
333    public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
334                                                double mean) {
335        return new LargeMeanPoissonSampler(rng, mean);
336    }
337    /**
338     * Gets the initialisation state of the sampler.
339     *
340     * <p>The state is computed using an integer {@code lambda} value of
341     * {@code lambda = (int)Math.floor(mean)}.
342     *
343     * <p>The state will be suitable for reconstructing a new sampler with a mean
344     * in the range {@code lambda <= mean < lambda+1} using
345     * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
346     *
347     * @return the state
348     */
349    LargeMeanPoissonSamplerState getState() {
350        return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
351                delta, halfDelta, twolpd, p1, p2, c1);
352    }
353
354    /**
355     * Encapsulate the state of the sampler. The state is valid for construction of
356     * a sampler in the range {@code lambda <= mean < lambda+1}.
357     *
358     * <p>This class is immutable.
359     *
360     * @see #getLambda()
361     */
362    static final class LargeMeanPoissonSamplerState {
363        /** Algorithm constant {@code lambda}. */
364        private final double lambda;
365        /** Algorithm constant {@code logLambda}. */
366        private final double logLambda;
367        /** Algorithm constant {@code logLambdaFactorial}. */
368        private final double logLambdaFactorial;
369        /** Algorithm constant {@code delta}. */
370        private final double delta;
371        /** Algorithm constant {@code halfDelta}. */
372        private final double halfDelta;
373        /** Algorithm constant {@code twolpd}. */
374        private final double twolpd;
375        /** Algorithm constant {@code p1}. */
376        private final double p1;
377        /** Algorithm constant {@code p2}. */
378        private final double p2;
379        /** Algorithm constant {@code c1}. */
380        private final double c1;
381
382        /**
383         * Creates the state.
384         *
385         * <p>The state is valid for construction of a sampler in the range
386         * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
387         *
388         * @param lambda the lambda
389         * @param logLambda the log lambda
390         * @param logLambdaFactorial the log lambda factorial
391         * @param delta the delta
392         * @param halfDelta the half delta
393         * @param twolpd the two lambda plus delta
394         * @param p1 the p1 constant
395         * @param p2 the p2 constant
396         * @param c1 the c1 constant
397         */
398        LargeMeanPoissonSamplerState(double lambda, double logLambda,
399                double logLambdaFactorial, double delta, double halfDelta, double twolpd,
400                double p1, double p2, double c1) {
401            this.lambda = lambda;
402            this.logLambda = logLambda;
403            this.logLambdaFactorial = logLambdaFactorial;
404            this.delta = delta;
405            this.halfDelta = halfDelta;
406            this.twolpd = twolpd;
407            this.p1 = p1;
408            this.p2 = p2;
409            this.c1 = c1;
410        }
411
412        /**
413         * Get the lambda value for the state.
414         *
415         * <p>Equal to {@code floor(mean)} for a Poisson sampler.
416         * @return the lambda value
417         */
418        int getLambda() {
419            return (int) getLambdaRaw();
420        }
421
422        /**
423         * @return algorithm constant {@code lambda}
424         */
425        double getLambdaRaw() {
426            return lambda;
427        }
428
429        /**
430         * @return algorithm constant {@code logLambda}
431         */
432        double getLogLambda() {
433            return logLambda;
434        }
435
436        /**
437         * @return algorithm constant {@code logLambdaFactorial}
438         */
439        double getLogLambdaFactorial() {
440            return logLambdaFactorial;
441        }
442
443        /**
444         * @return algorithm constant {@code delta}
445         */
446        double getDelta() {
447            return delta;
448        }
449
450        /**
451         * @return algorithm constant {@code halfDelta}
452         */
453        double getHalfDelta() {
454            return halfDelta;
455        }
456
457        /**
458         * @return algorithm constant {@code twolpd}
459         */
460        double getTwolpd() {
461            return twolpd;
462        }
463
464        /**
465         * @return algorithm constant {@code p1}
466         */
467        double getP1() {
468            return p1;
469        }
470
471        /**
472         * @return algorithm constant {@code p2}
473         */
474        double getP2() {
475            return p2;
476        }
477
478        /**
479         * @return algorithm constant {@code c1}
480         */
481        double getC1() {
482            return c1;
483        }
484    }
485}