001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
import java.util.Arrays;

import org.apache.commons.rng.UniformRandomProvider;
020
021/**
022 * Sampler for a discrete distribution using an optimised look-up table.
023 *
024 * <ul>
025 *  <li>
026 *   The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
027 *   in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
028 *   Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
029 *  </li>
030 * </ul>
031 *
032 * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
033 *
034 * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
035 * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
036 * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
037 * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
038 * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
039 *
040 * <p>The sampler supports the following distributions:</p>
041 *
042 * <ul>
043 *  <li>Enumerated distribution (probabilities must be provided for each sample)</li>
044 *  <li>Poisson distribution up to {@code mean = 1024}</li>
045 *  <li>Binomial distribution up to {@code trials = 65535}</li>
046 * </ul>
047 *
 * @see <a href="https://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
 * 11, Issue 3</a>
050 * @since 1.3
051 */
052public final class MarsagliaTsangWangDiscreteSampler {
053    /** The value 2<sup>8</sup> as an {@code int}. */
054    private static final int INT_8 = 1 << 8;
055    /** The value 2<sup>16</sup> as an {@code int}. */
056    private static final int INT_16 = 1 << 16;
057    /** The value 2<sup>30</sup> as an {@code int}. */
058    private static final int INT_30 = 1 << 30;
059    /** The value 2<sup>31</sup> as a {@code double}. */
060    private static final double DOUBLE_31 = 1L << 31;
061
062    // =========================================================================
063    // Implementation note:
064    //
065    // This sampler uses prepared look-up tables that are searched using a single
066    // random int variate. The look-up tables contain the sample value. The tables
067    // are constructed using probabilities that sum to 2^30. The original paper
068    // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables
069    // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6)
070    // is supported using 5 look-up tables.
071    //
072    // The implementations use 8, 16 or 32 bit storage tables to support different
073    // distribution sizes with optimal storage. Separate class implementations of
074    // the same algorithm allow array storage to be accessed directly from 1D tables.
075    // This provides a performance gain over using: abstracted storage accessed via
076    // an interface; or a single 2D table.
077    //
078    // To allow the optimal implementation to be chosen the sampler is created
079    // using factory methods. The sampler supports any probability distribution
080    // when provided via an array of probabilities and the Poisson and Binomial
081    // distributions for a restricted set of parameters. The restrictions are
082    // imposed by the requirement to compute the entire probability distribution
083    // from the controlling parameter(s) using a recursive method. Factory
084    // constructors return a SharedStateDiscreteSampler instance. Each distribution
085    // type is contained in an inner class.
086    // =========================================================================
087
088    /**
089     * The base class for Marsaglia-Tsang-Wang samplers.
090     */
091    private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
092            implements SharedStateDiscreteSampler {
093        /** Underlying source of randomness. */
094        protected final UniformRandomProvider rng;
095
096        /** The name of the distribution. */
097        private final String distributionName;
098
099        /**
100         * @param rng Generator of uniformly distributed random numbers.
101         * @param distributionName Distribution name.
102         */
103        AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104                                                  String distributionName) {
105            this.rng = rng;
106            this.distributionName = distributionName;
107        }
108
109        /**
110         * @param rng Generator of uniformly distributed random numbers.
111         * @param source Source to copy.
112         */
113        AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114                                                  AbstractMarsagliaTsangWangDiscreteSampler source) {
115            this.rng = rng;
116            this.distributionName = source.distributionName;
117        }
118
119        /** {@inheritDoc} */
120        @Override
121        public String toString() {
122            return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123        }
124    }
125
126    /**
127     * An implementation for the sample algorithm based on the decomposition of the
128     * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage.
129     */
130    private static final class MarsagliaTsangWangBase64Int8DiscreteSampler
131        extends AbstractMarsagliaTsangWangDiscreteSampler {
132        /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */
133        private static final int MASK = 0xff;
134
135        /** Limit for look-up table 1. */
136        private final int t1;
137        /** Limit for look-up table 2. */
138        private final int t2;
139        /** Limit for look-up table 3. */
140        private final int t3;
141        /** Limit for look-up table 4. */
142        private final int t4;
143
144        /** Look-up table table1. */
145        private final byte[] table1;
146        /** Look-up table table2. */
147        private final byte[] table2;
148        /** Look-up table table3. */
149        private final byte[] table3;
150        /** Look-up table table4. */
151        private final byte[] table4;
152        /** Look-up table table5. */
153        private final byte[] table5;
154
155        /**
156         * @param rng Generator of uniformly distributed random numbers.
157         * @param distributionName Distribution name.
158         * @param prob The probabilities.
159         * @param offset The offset (must be positive).
160         */
161        MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162                                                    String distributionName,
163                                                    int[] prob,
164                                                    int offset) {
165            super(rng, distributionName);
166
167            // Get table sizes for each base-64 digit
168            int n1 = 0;
169            int n2 = 0;
170            int n3 = 0;
171            int n4 = 0;
172            int n5 = 0;
173            for (final int m : prob) {
174                n1 += getBase64Digit(m, 1);
175                n2 += getBase64Digit(m, 2);
176                n3 += getBase64Digit(m, 3);
177                n4 += getBase64Digit(m, 4);
178                n5 += getBase64Digit(m, 5);
179            }
180
181            table1 = new byte[n1];
182            table2 = new byte[n2];
183            table3 = new byte[n3];
184            table4 = new byte[n4];
185            table5 = new byte[n5];
186
187            // Compute offsets
188            t1 = n1 << 24;
189            t2 = t1 + (n2 << 18);
190            t3 = t2 + (n3 << 12);
191            t4 = t3 + (n4 << 6);
192            n1 = 0;
193            n2 = 0;
194            n3 = 0;
195            n4 = 0;
196            n5 = 0;
197
198            // Fill tables
199            for (int i = 0; i < prob.length; i++) {
200                final int m = prob[i];
201                // Primitive type conversion will extract lower 8 bits
202                final byte k = (byte) (i + offset);
203                n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
204                n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
205                n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
206                n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
207                n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
208            }
209        }
210
211        /**
212         * @param rng Generator of uniformly distributed random numbers.
213         * @param source Source to copy.
214         */
215        private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
216                MarsagliaTsangWangBase64Int8DiscreteSampler source) {
217            super(rng, source);
218            t1 = source.t1;
219            t2 = source.t2;
220            t3 = source.t3;
221            t4 = source.t4;
222            table1 = source.table1;
223            table2 = source.table2;
224            table3 = source.table3;
225            table4 = source.table4;
226            table5 = source.table5;
227        }
228
229        /**
230         * Fill the table with the value.
231         *
232         * @param table Table.
233         * @param from Lower bound index (inclusive)
234         * @param to Upper bound index (exclusive)
235         * @param value Value.
236         * @return the upper bound index
237         */
238        private static int fill(byte[] table, int from, int to, byte value) {
239            for (int i = from; i < to; i++) {
240                table[i] = value;
241            }
242            return to;
243        }
244
245        @Override
246        public int sample() {
247            final int j = rng.nextInt() >>> 2;
248            if (j < t1) {
249                return table1[j >>> 24] & MASK;
250            }
251            if (j < t2) {
252                return table2[(j - t1) >>> 18] & MASK;
253            }
254            if (j < t3) {
255                return table3[(j - t2) >>> 12] & MASK;
256            }
257            if (j < t4) {
258                return table4[(j - t3) >>> 6] & MASK;
259            }
260            // Note the tables are filled on the assumption that the sum of the probabilities.
261            // is >=2^30. If this is not true then the final table table5 will be smaller by the
262            // difference. So the tables *must* be constructed correctly.
263            return table5[j - t4] & MASK;
264        }
265
266        @Override
267        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
268            return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
269        }
270    }
271
272    /**
273     * An implementation for the sample algorithm based on the decomposition of the
274     * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage.
275     */
276    private static final class MarsagliaTsangWangBase64Int16DiscreteSampler
277        extends AbstractMarsagliaTsangWangDiscreteSampler {
278        /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */
279        private static final int MASK = 0xffff;
280
281        /** Limit for look-up table 1. */
282        private final int t1;
283        /** Limit for look-up table 2. */
284        private final int t2;
285        /** Limit for look-up table 3. */
286        private final int t3;
287        /** Limit for look-up table 4. */
288        private final int t4;
289
290        /** Look-up table table1. */
291        private final short[] table1;
292        /** Look-up table table2. */
293        private final short[] table2;
294        /** Look-up table table3. */
295        private final short[] table3;
296        /** Look-up table table4. */
297        private final short[] table4;
298        /** Look-up table table5. */
299        private final short[] table5;
300
301        /**
302         * @param rng Generator of uniformly distributed random numbers.
303         * @param distributionName Distribution name.
304         * @param prob The probabilities.
305         * @param offset The offset (must be positive).
306         */
307        MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
308                                                     String distributionName,
309                                                     int[] prob,
310                                                     int offset) {
311            super(rng, distributionName);
312
313            // Get table sizes for each base-64 digit
314            int n1 = 0;
315            int n2 = 0;
316            int n3 = 0;
317            int n4 = 0;
318            int n5 = 0;
319            for (final int m : prob) {
320                n1 += getBase64Digit(m, 1);
321                n2 += getBase64Digit(m, 2);
322                n3 += getBase64Digit(m, 3);
323                n4 += getBase64Digit(m, 4);
324                n5 += getBase64Digit(m, 5);
325            }
326
327            table1 = new short[n1];
328            table2 = new short[n2];
329            table3 = new short[n3];
330            table4 = new short[n4];
331            table5 = new short[n5];
332
333            // Compute offsets
334            t1 = n1 << 24;
335            t2 = t1 + (n2 << 18);
336            t3 = t2 + (n3 << 12);
337            t4 = t3 + (n4 << 6);
338            n1 = 0;
339            n2 = 0;
340            n3 = 0;
341            n4 = 0;
342            n5 = 0;
343
344            // Fill tables
345            for (int i = 0; i < prob.length; i++) {
346                final int m = prob[i];
347                // Primitive type conversion will extract lower 16 bits
348                final short k = (short) (i + offset);
349                n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
350                n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
351                n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
352                n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
353                n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
354            }
355        }
356
357        /**
358         * @param rng Generator of uniformly distributed random numbers.
359         * @param source Source to copy.
360         */
361        private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
362                MarsagliaTsangWangBase64Int16DiscreteSampler source) {
363            super(rng, source);
364            t1 = source.t1;
365            t2 = source.t2;
366            t3 = source.t3;
367            t4 = source.t4;
368            table1 = source.table1;
369            table2 = source.table2;
370            table3 = source.table3;
371            table4 = source.table4;
372            table5 = source.table5;
373        }
374
375        /**
376         * Fill the table with the value.
377         *
378         * @param table Table.
379         * @param from Lower bound index (inclusive)
380         * @param to Upper bound index (exclusive)
381         * @param value Value.
382         * @return the upper bound index
383         */
384        private static int fill(short[] table, int from, int to, short value) {
385            for (int i = from; i < to; i++) {
386                table[i] = value;
387            }
388            return to;
389        }
390
391        @Override
392        public int sample() {
393            final int j = rng.nextInt() >>> 2;
394            if (j < t1) {
395                return table1[j >>> 24] & MASK;
396            }
397            if (j < t2) {
398                return table2[(j - t1) >>> 18] & MASK;
399            }
400            if (j < t3) {
401                return table3[(j - t2) >>> 12] & MASK;
402            }
403            if (j < t4) {
404                return table4[(j - t3) >>> 6] & MASK;
405            }
406            // Note the tables are filled on the assumption that the sum of the probabilities.
407            // is >=2^30. If this is not true then the final table table5 will be smaller by the
408            // difference. So the tables *must* be constructed correctly.
409            return table5[j - t4] & MASK;
410        }
411
412        @Override
413        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
414            return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
415        }
416    }
417
418    /**
419     * An implementation for the sample algorithm based on the decomposition of the
420     * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage.
421     */
422    private static final class MarsagliaTsangWangBase64Int32DiscreteSampler
423        extends AbstractMarsagliaTsangWangDiscreteSampler {
424        /** Limit for look-up table 1. */
425        private final int t1;
426        /** Limit for look-up table 2. */
427        private final int t2;
428        /** Limit for look-up table 3. */
429        private final int t3;
430        /** Limit for look-up table 4. */
431        private final int t4;
432
433        /** Look-up table table1. */
434        private final int[] table1;
435        /** Look-up table table2. */
436        private final int[] table2;
437        /** Look-up table table3. */
438        private final int[] table3;
439        /** Look-up table table4. */
440        private final int[] table4;
441        /** Look-up table table5. */
442        private final int[] table5;
443
444        /**
445         * @param rng Generator of uniformly distributed random numbers.
446         * @param distributionName Distribution name.
447         * @param prob The probabilities.
448         * @param offset The offset (must be positive).
449         */
450        MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
451                                                     String distributionName,
452                                                     int[] prob,
453                                                     int offset) {
454            super(rng, distributionName);
455
456            // Get table sizes for each base-64 digit
457            int n1 = 0;
458            int n2 = 0;
459            int n3 = 0;
460            int n4 = 0;
461            int n5 = 0;
462            for (final int m : prob) {
463                n1 += getBase64Digit(m, 1);
464                n2 += getBase64Digit(m, 2);
465                n3 += getBase64Digit(m, 3);
466                n4 += getBase64Digit(m, 4);
467                n5 += getBase64Digit(m, 5);
468            }
469
470            table1 = new int[n1];
471            table2 = new int[n2];
472            table3 = new int[n3];
473            table4 = new int[n4];
474            table5 = new int[n5];
475
476            // Compute offsets
477            t1 = n1 << 24;
478            t2 = t1 + (n2 << 18);
479            t3 = t2 + (n3 << 12);
480            t4 = t3 + (n4 << 6);
481            n1 = 0;
482            n2 = 0;
483            n3 = 0;
484            n4 = 0;
485            n5 = 0;
486
487            // Fill tables
488            for (int i = 0; i < prob.length; i++) {
489                final int m = prob[i];
490                final int k = i + offset;
491                n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
492                n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
493                n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
494                n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
495                n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
496            }
497        }
498
499        /**
500         * @param rng Generator of uniformly distributed random numbers.
501         * @param source Source to copy.
502         */
503        private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
504                MarsagliaTsangWangBase64Int32DiscreteSampler source) {
505            super(rng, source);
506            t1 = source.t1;
507            t2 = source.t2;
508            t3 = source.t3;
509            t4 = source.t4;
510            table1 = source.table1;
511            table2 = source.table2;
512            table3 = source.table3;
513            table4 = source.table4;
514            table5 = source.table5;
515        }
516
517        /**
518         * Fill the table with the value.
519         *
520         * @param table Table.
521         * @param from Lower bound index (inclusive)
522         * @param to Upper bound index (exclusive)
523         * @param value Value.
524         * @return the upper bound index
525         */
526        private static int fill(int[] table, int from, int to, int value) {
527            for (int i = from; i < to; i++) {
528                table[i] = value;
529            }
530            return to;
531        }
532
533        @Override
534        public int sample() {
535            final int j = rng.nextInt() >>> 2;
536            if (j < t1) {
537                return table1[j >>> 24];
538            }
539            if (j < t2) {
540                return table2[(j - t1) >>> 18];
541            }
542            if (j < t3) {
543                return table3[(j - t2) >>> 12];
544            }
545            if (j < t4) {
546                return table4[(j - t3) >>> 6];
547            }
548            // Note the tables are filled on the assumption that the sum of the probabilities.
549            // is >=2^30. If this is not true then the final table table5 will be smaller by the
550            // difference. So the tables *must* be constructed correctly.
551            return table5[j - t4];
552        }
553
554        @Override
555        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
556            return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
557        }
558    }
559
560
561
562    /** Class contains only static methods. */
563    private MarsagliaTsangWangDiscreteSampler() {}
564
565    /**
566     * Gets the k<sup>th</sup> base 64 digit of {@code m}.
567     *
568     * @param m the value m.
569     * @param k the digit.
570     * @return the base 64 digit
571     */
572    private static int getBase64Digit(int m, int k) {
573        return (m >>> (30 - 6 * k)) & 63;
574    }
575
576    /**
577     * Convert the probability to an integer in the range [0,2^30]. This is the numerator of
578     * a fraction with assumed denominator 2<sup>30</sup>.
579     *
580     * @param p Probability.
581     * @return the fraction numerator
582     */
583    private static int toUnsignedInt30(double p) {
584        return (int) (p * INT_30 + 0.5);
585    }
586
587    /**
588     * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
589     * {@code i + offset}.
590     *
591     * <p>The sum of the probabilities must be {@code >=} 2<sup>30</sup>. Only the
592     * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
593     *
594     * @param rng Generator of uniformly distributed random numbers.
595     * @param distributionName Distribution name.
596     * @param prob The probabilities.
597     * @param offset The offset (must be positive).
598     * @return Sampler.
599     */
600    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
601                                                            String distributionName,
602                                                            int[] prob,
603                                                            int offset) {
604        // Note: No argument checks for private method.
605
606        // Choose implementation based on the maximum index
607        final int maxIndex = prob.length + offset - 1;
608        if (maxIndex < INT_8) {
609            return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
610        }
611        if (maxIndex < INT_16) {
612            return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
613        }
614        return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
615    }
616
617    // =========================================================================
618    // The following public classes provide factory methods to construct a sampler for:
619    // - Enumerated probability distribution (from provided double[] probabilities)
620    // - Poisson distribution for mean <= 1024
621    // - Binomial distribution for trials <= 65535
622    // =========================================================================
623
624    /**
625     * Create a sampler for an enumerated distribution of {@code n} values each with an
626     * associated probability.
627     * The samples corresponding to each probability are assumed to be a natural sequence
628     * starting at zero.
629     */
630    public static final class Enumerated {
631        /** The name of the enumerated probability distribution. */
632        private static final String ENUMERATED_NAME = "Enumerated";
633
634        /** Class contains only static methods. */
635        private Enumerated() {}
636
637        /**
638         * Creates a sampler for an enumerated distribution of {@code n} values each with an
639         * associated probability.
640         *
641         * <p>The probabilities will be normalised using their sum. The only requirement
642         * is the sum is positive.</p>
643         *
644         * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that
645         * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during
646         * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup>
647         * will not be observed in samples.</p>
648         *
649         * @param rng Generator of uniformly distributed random numbers.
650         * @param probabilities The list of probabilities.
651         * @return Sampler.
652         * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
653         * probability is negative, infinite or {@code NaN}, or the sum of all
654         * probabilities is not strictly positive.
655         */
656        public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
657                                                    double[] probabilities) {
658            return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
659        }
660
661        /**
662         * Normalise the probabilities to integers that sum to 2<sup>30</sup>.
663         *
664         * @param probabilities The list of probabilities.
665         * @return the normalised probabilities.
666         * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
667         * probability is negative, infinite or {@code NaN}, or the sum of all
668         * probabilities is not strictly positive.
669         */
670        private static int[] normaliseProbabilities(double[] probabilities) {
671            final double sumProb = InternalUtils.validateProbabilities(probabilities);
672
673            // Compute the normalisation: 2^30 / sum
674            final double normalisation = INT_30 / sumProb;
675            final int[] prob = new int[probabilities.length];
676            int sum = 0;
677            int max = 0;
678            int mode = 0;
679            for (int i = 0; i < prob.length; i++) {
680                // Add 0.5 for rounding
681                final int p = (int) (probabilities[i] * normalisation + 0.5);
682                sum += p;
683                // Find the mode (maximum probability)
684                if (max < p) {
685                    max = p;
686                    mode = i;
687                }
688                prob[i] = p;
689            }
690
691            // The sum must be >= 2^30.
692            // Here just compensate the difference onto the highest probability.
693            prob[mode] += INT_30 - sum;
694
695            return prob;
696        }
697    }
698
699    /**
700     * Create a sampler for the Poisson distribution.
701     */
702    public static final class Poisson {
703        /** The name of the Poisson distribution. */
704        private static final String POISSON_NAME = "Poisson";
705
706        /**
707         * Upper bound on the mean for the Poisson distribution.
708         *
709         * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
710         * limit but the code fails at mean {@code >= 1941} as the transform to compute p(x=mode)
711         * produces infinity. Use a conservative limit of 1024.</p>
712         */
713
714        private static final double MAX_MEAN = 1024;
715        /**
716         * The threshold for the mean of the Poisson distribution to switch the method used
717         * to compute the probabilities. This is taken from the example software provided by
718         * Marsaglia, et al (2004).
719         */
720        private static final double MEAN_THRESHOLD = 21.4;
721
722        /** Class contains only static methods. */
723        private Poisson() {}
724
725        /**
726         * Creates a sampler for the Poisson distribution.
727         *
728         * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
729         *
730         * <p>Storage requirements depend on the tabulated probability values. Example storage
731         * requirements are listed below.</p>
732         *
733         * <pre>
734         * mean      table size     kB
735         * 0.25      882            0.88
736         * 0.5       1135           1.14
737         * 1         1200           1.20
738         * 2         1451           1.45
739         * 4         1955           1.96
740         * 8         2961           2.96
741         * 16        4410           4.41
742         * 32        6115           6.11
743         * 64        8499           8.50
744         * 128       11528          11.53
745         * 256       15935          31.87
746         * 512       20912          41.82
747         * 1024      30614          61.23
748         * </pre>
749         *
750         * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p>
751         *
752         * @param rng Generator of uniformly distributed random numbers.
753         * @param mean Mean.
754         * @return Sampler.
755         * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
756         */
757        public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
758                                                    double mean) {
759            validatePoissonDistributionParameters(mean);
760
761            // Create the distribution either from X=0 or from X=mode when the mean is high.
762            return mean < MEAN_THRESHOLD ?
763                createPoissonDistributionFromX0(rng, mean) :
764                createPoissonDistributionFromXMode(rng, mean);
765        }
766
767        /**
768         * Validate the Poisson distribution parameters.
769         *
770         * @param mean Mean.
771         * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
772         */
773        private static void validatePoissonDistributionParameters(double mean) {
774            InternalUtils.requireStrictlyPositive(mean, "mean");
775            if (mean > MAX_MEAN) {
776                throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
777            }
778        }
779
780        /**
781         * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}.
782         *
783         * @param rng Generator of uniformly distributed random numbers.
784         * @param mean Mean.
785         * @return Sampler.
786         */
787        private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
788                UniformRandomProvider rng, double mean) {
789            final double p0 = Math.exp(-mean);
790
791            // Recursive update of Poisson probability until the value is too small
792            // p(x + 1) = p(x) * mean / (x + 1)
793            double p = p0;
794            int i = 1;
795            while (p * DOUBLE_31 >= 1) {
796                p *= mean / i++;
797            }
798
799            // Probabilities are 30-bit integers, assumed denominator 2^30
800            final int size = i - 1;
801            final int[] prob = new int[size];
802
803            p = p0;
804            prob[0] = toUnsignedInt30(p);
805            // The sum must exceed 2^30. In edges cases this is false due to round-off.
806            int sum = prob[0];
807            for (i = 1; i < prob.length; i++) {
808                p *= mean / i;
809                prob[i] = toUnsignedInt30(p);
810                sum += prob[i];
811            }
812
813            // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
814            prob[(int) mean] += Math.max(0, INT_30 - sum);
815
816            // Note: offset = 0
817            return createSampler(rng, POISSON_NAME, prob, 0);
818        }
819
820        /**
821         * Creates the Poisson distribution by computing probabilities recursively upward and downward
822         * from {@code X=mode}, the location of the largest p-value.
823         *
824         * @param rng Generator of uniformly distributed random numbers.
825         * @param mean Mean.
826         * @return Sampler.
827         */
828        private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
829                UniformRandomProvider rng, double mean) {
830            // If mean >= 21.4, generate from largest p-value up, then largest down.
831            // The largest p-value will be at the mode (floor(mean)).
832
833            // Find p(x=mode)
834            final int mode = (int) mean;
835            // This transform is stable until mean >= 1941 where p will result in Infinity
836            // before the divisor i is large enough to start reducing the product (i.e. i > c).
837            final double c = mean * Math.exp(-mean / mode);
838            double p = 1.0;
839            for (int i = 1; i <= mode; i++) {
840                p *= c / i;
841            }
842            final double pMode = p;
843
844            // Find the upper limit using recursive computation of the p-value.
845            // Note this will exit when i overflows to negative so no check on the range
846            int i = mode + 1;
847            while (p * DOUBLE_31 >= 1) {
848                p *= mean / i++;
849            }
850            final int last = i - 2;
851
852            // Find the lower limit using recursive computation of the p-value.
853            p = pMode;
854            int j = -1;
855            for (i = mode - 1; i >= 0; i--) {
856                p *= (i + 1) / mean;
857                if (p * DOUBLE_31 < 1) {
858                    j = i;
859                    break;
860                }
861            }
862
863            // Probabilities are 30-bit integers, assumed denominator 2^30.
864            // This is the minimum sample value: prob[x - offset] = p(x)
865            final int offset = j + 1;
866            final int size = last - offset + 1;
867            final int[] prob = new int[size];
868
869            p = pMode;
870            prob[mode - offset] = toUnsignedInt30(p);
871            // The sum must exceed 2^30. In edges cases this is false due to round-off.
872            int sum = prob[mode - offset];
873            // From mode to upper limit
874            for (i = mode + 1; i <= last; i++) {
875                p *= mean / i;
876                prob[i - offset] = toUnsignedInt30(p);
877                sum += prob[i - offset];
878            }
879            // From mode to lower limit
880            p = pMode;
881            for (i = mode - 1; i >= offset; i--) {
882                p *= (i + 1) / mean;
883                prob[i - offset] = toUnsignedInt30(p);
884                sum += prob[i - offset];
885            }
886
887            // If the sum is < 2^30 add the remaining sum to the mode.
888            // If above 2^30 then the effect is truncation of the long tail of the distribution.
889            prob[mode - offset] += Math.max(0, INT_30 - sum);
890
891            return createSampler(rng, POISSON_NAME, prob, offset);
892        }
893    }
894
895    /**
896     * Create a sampler for the Binomial distribution.
897     */
898    public static final class Binomial {
899        /** The name of the Binomial distribution. */
900        private static final String BINOMIAL_NAME = "Binomial";
901
902        /**
903         * Return a fixed result for the Binomial distribution. This is a special class to handle
904         * an edge case of probability of success equal to 0 or 1.
905         */
906        private static final class MarsagliaTsangWangFixedResultBinomialSampler
907            extends AbstractMarsagliaTsangWangDiscreteSampler {
908            /** The result. */
909            private final int result;
910
911            /**
912             * @param result Result.
913             */
914            MarsagliaTsangWangFixedResultBinomialSampler(int result) {
915                super(null, BINOMIAL_NAME);
916                this.result = result;
917            }
918
919            @Override
920            public int sample() {
921                return result;
922            }
923
924            @Override
925            public String toString() {
926                return BINOMIAL_NAME + " deviate";
927            }
928
929            @Override
930            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
931                // No shared state
932                return this;
933            }
934        }
935
936        /**
937         * Return an inversion result for the Binomial distribution. This assumes the
938         * following:
939         *
940         * <pre>
941         * Binomial(n, p) = 1 - Binomial(n, 1 - p)
942         * </pre>
943         */
944        private static final class MarsagliaTsangWangInversionBinomialSampler
945            extends AbstractMarsagliaTsangWangDiscreteSampler {
946            /** The number of trials. */
947            private final int trials;
948            /** The Binomial distribution sampler. */
949            private final SharedStateDiscreteSampler sampler;
950
951            /**
952             * @param trials Number of trials.
953             * @param sampler Binomial distribution sampler.
954             */
955            MarsagliaTsangWangInversionBinomialSampler(int trials,
956                                                       SharedStateDiscreteSampler sampler) {
957                super(null, BINOMIAL_NAME);
958                this.trials = trials;
959                this.sampler = sampler;
960            }
961
962            @Override
963            public int sample() {
964                return trials - sampler.sample();
965            }
966
967            @Override
968            public String toString() {
969                return sampler.toString();
970            }
971
972            @Override
973            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
974                return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
975                    this.sampler.withUniformRandomProvider(rng));
976            }
977        }
978
979        /** Class contains only static methods. */
980        private Binomial() {}
981
982        /**
983         * Creates a sampler for the Binomial distribution.
984         *
985         * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
986         *
987         * <p>Storage requirements depend on the tabulated probability values. Example storage
988         * requirements are listed below (in kB).</p>
989         *
990         * <pre>
991         *          p
992         * trials   0.5    0.1   0.01  0.001
993         *    4    0.06   0.63   0.44   0.44
994         *   16    0.69   1.14   0.76   0.44
995         *   64    4.73   2.40   1.14   0.51
996         *  256    8.63   5.17   1.89   0.82
997         * 1024   31.12   9.45   3.34   0.89
998         * </pre>
999         *
1000         * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed.
1001         * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large
1002         * and/or {@code p} to be small. In this case an exception is raised.</p>
1003         *
1004         * @param rng Generator of uniformly distributed random numbers.
1005         * @param trials Number of trials.
1006         * @param probabilityOfSuccess Probability of success (p).
1007         * @return Sampler.
1008         * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
1009         * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot
1010         * be computed.
1011         */
1012        public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1013                                                    int trials,
1014                                                    double probabilityOfSuccess) {
1015            validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1016
1017            // Handle edge cases
1018            if (probabilityOfSuccess == 0) {
1019                return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1020            }
1021            if (probabilityOfSuccess == 1) {
1022                return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1023            }
1024
1025            // Check the supported size.
1026            if (trials >= INT_16) {
1027                throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1028            }
1029
1030            return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1031        }
1032
1033        /**
1034         * Validate the Binomial distribution parameters.
1035         *
1036         * @param trials Number of trials.
1037         * @param probabilityOfSuccess Probability of success (p).
1038         * @throws IllegalArgumentException if {@code trials < 0} or
1039         * {@code p} is not in the range {@code [0-1]}
1040         */
1041        private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1042            if (trials < 0) {
1043                throw new IllegalArgumentException("Trials is not positive: " + trials);
1044            }
1045            InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success");
1046        }
1047
1048        /**
1049         * Creates the Binomial distribution sampler.
1050         *
1051         * <p>This assumes the parameters for the distribution are valid. The method
1052         * will only fail if the initial probability for {@code X=0} is zero.</p>
1053         *
1054         * @param rng Generator of uniformly distributed random numbers.
1055         * @param trials Number of trials.
1056         * @param probabilityOfSuccess Probability of success (p).
1057         * @return Sampler.
1058         * @throws IllegalArgumentException if the probability distribution cannot be
1059         * computed.
1060         */
1061        private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1062                UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1063
1064            // The maximum supported value for Math.exp is approximately -744.
1065            // This occurs when trials is large and p is close to 1.
1066            // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
1067            final boolean useInversion = probabilityOfSuccess > 0.5;
1068            final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1069
1070            // Check if the distribution can be computed
1071            final double p0 = Math.exp(trials * Math.log(1 - p));
1072            if (p0 < Double.MIN_VALUE) {
1073                throw new IllegalArgumentException("Unable to compute distribution");
1074            }
1075
1076            // First find size of probability array
1077            double t = p0;
1078            final double h = p / (1 - p);
1079            // Find first probability above the threshold of 2^-31
1080            int begin = 0;
1081            if (t * DOUBLE_31 < 1) {
1082                // Somewhere after p(0)
1083                // Note:
1084                // If this loop is entered p(0) is < 2^-31.
1085                // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
1086                // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
1087                for (int i = 1; i <= trials; i++) {
1088                    t *= (trials + 1 - i) * h / i;
1089                    if (t * DOUBLE_31 >= 1) {
1090                        begin = i;
1091                        break;
1092                    }
1093                }
1094            }
1095            // Find last probability
1096            int end = trials;
1097            for (int i = begin + 1; i <= trials; i++) {
1098                t *= (trials + 1 - i) * h / i;
1099                if (t * DOUBLE_31 < 1) {
1100                    end = i - 1;
1101                    break;
1102                }
1103            }
1104
1105            return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1106                    p0, begin, end);
1107        }
1108
1109        /**
1110         * Creates the Binomial distribution sampler using only the probability values for {@code X}
1111         * between the begin and the end (inclusive).
1112         *
1113         * @param rng Generator of uniformly distributed random numbers.
1114         * @param trials Number of trials.
1115         * @param p Probability of success (p).
1116         * @param useInversion Set to {@code true} if the probability was inverted.
1117         * @param p0 Probability at {@code X=0}
1118         * @param begin Begin value {@code X} for the distribution.
1119         * @param end End value {@code X} for the distribution.
1120         * @return Sampler.
1121         */
1122        private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1123                UniformRandomProvider rng, int trials, double p,
1124                boolean useInversion, double p0, int begin, int end) {
1125
1126            // Assign probability values as 30-bit integers
1127            final int size = end - begin + 1;
1128            final int[] prob = new int[size];
1129            double t = p0;
1130            final double h = p / (1 - p);
1131            for (int i = 1; i <= begin; i++) {
1132                t *= (trials + 1 - i) * h / i;
1133            }
1134            int sum = toUnsignedInt30(t);
1135            prob[0] = sum;
1136            for (int i = begin + 1; i <= end; i++) {
1137                t *= (trials + 1 - i) * h / i;
1138                prob[i - begin] = toUnsignedInt30(t);
1139                sum += prob[i - begin];
1140            }
1141
1142            // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
1143            // If above 2^30 then the effect is truncation of the long tail of the distribution.
1144            final int mode = (int) ((trials + 1) * p) - begin;
1145            prob[mode] += Math.max(0, INT_30 - sum);
1146
1147            final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1148
1149            // Check if an inversion was made
1150            return useInversion ?
1151                   new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1152                   sampler;
1153        }
1154    }
1155}