001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020
021/**
022 * Sampler for a discrete distribution using an optimised look-up table.
023 *
024 * <ul>
025 *  <li>
026 *   The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
027 *   in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
028 *   Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
029 *  </li>
030 * </ul>
031 *
032 * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
033 *
034 * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
035 * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
036 * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
037 * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
038 * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
039 *
040 * <p>The sampler supports the following distributions:</p>
041 *
042 * <ul>
043 *  <li>Enumerated distribution (probabilities must be provided for each sample)
044 *  <li>Poisson distribution up to {@code mean = 1024}
045 *  <li>Binomial distribution up to {@code trials = 65535}
046 * </ul>
047 *
 * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
 * 11, Issue 3</a>
050 * @since 1.3
051 */
052public final class MarsagliaTsangWangDiscreteSampler {
    /**
     * The value 2<sup>8</sup> as an {@code int}. A maximum sample value below this
     * selects the 8-bit storage implementation.
     */
    private static final int INT_8 = 1 << 8;
    /**
     * The value 2<sup>16</sup> as an {@code int}. A maximum sample value below this
     * (and not below 2<sup>8</sup>) selects the 16-bit storage implementation.
     */
    private static final int INT_16 = 1 << 16;
    /**
     * The value 2<sup>30</sup> as an {@code int}. This is the assumed denominator of the
     * integer probabilities and the total that the probabilities are normalised to.
     */
    private static final int INT_30 = 1 << 30;
    /**
     * The value 2<sup>31</sup> as a {@code double}. Used as a multiplier to test whether
     * a probability {@code p} satisfies {@code p * 2^31 >= 1}.
     */
    private static final double DOUBLE_31 = 1L << 31;
061
062    // =========================================================================
063    // Implementation note:
064    //
065    // This sampler uses prepared look-up tables that are searched using a single
066    // random int variate. The look-up tables contain the sample value. The tables
067    // are constructed using probabilities that sum to 2^30. The original paper
068    // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables
069    // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6)
070    // is supported using 5 look-up tables.
071    //
072    // The implementations use 8, 16 or 32 bit storage tables to support different
073    // distribution sizes with optimal storage. Separate class implementations of
074    // the same algorithm allow array storage to be accessed directly from 1D tables.
075    // This provides a performance gain over using: abstracted storage accessed via
076    // an interface; or a single 2D table.
077    //
078    // To allow the optimal implementation to be chosen the sampler is created
079    // using factory methods. The sampler supports any probability distribution
080    // when provided via an array of probabilities and the Poisson and Binomial
081    // distributions for a restricted set of parameters. The restrictions are
082    // imposed by the requirement to compute the entire probability distribution
083    // from the controlling parameter(s) using a recursive method. Factory
084    // constructors return a SharedStateDiscreteSampler instance. Each distribution
085    // type is contained in an inner class.
086    // =========================================================================
087
088    /**
089     * The base class for Marsaglia-Tsang-Wang samplers.
090     */
091    private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
092            implements SharedStateDiscreteSampler {
093        /** Underlying source of randomness. */
094        protected final UniformRandomProvider rng;
095
096        /** The name of the distribution. */
097        private final String distributionName;
098
099        /**
100         * @param rng Generator of uniformly distributed random numbers.
101         * @param distributionName Distribution name.
102         */
103        AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104                                                  String distributionName) {
105            this.rng = rng;
106            this.distributionName = distributionName;
107        }
108
109        /**
110         * @param rng Generator of uniformly distributed random numbers.
111         * @param source Source to copy.
112         */
113        AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114                                                  AbstractMarsagliaTsangWangDiscreteSampler source) {
115            this.rng = rng;
116            this.distributionName = source.distributionName;
117        }
118
119        /** {@inheritDoc} */
120        @Override
121        public String toString() {
122            return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123        }
124    }
125
126    /**
127     * An implementation for the sample algorithm based on the decomposition of the
128     * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage.
129     */
130    private static class MarsagliaTsangWangBase64Int8DiscreteSampler
131        extends AbstractMarsagliaTsangWangDiscreteSampler {
132        /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */
133        private static final int MASK = 0xff;
134
135        /** Limit for look-up table 1. */
136        private final int t1;
137        /** Limit for look-up table 2. */
138        private final int t2;
139        /** Limit for look-up table 3. */
140        private final int t3;
141        /** Limit for look-up table 4. */
142        private final int t4;
143
144        /** Look-up table table1. */
145        private final byte[] table1;
146        /** Look-up table table2. */
147        private final byte[] table2;
148        /** Look-up table table3. */
149        private final byte[] table3;
150        /** Look-up table table4. */
151        private final byte[] table4;
152        /** Look-up table table5. */
153        private final byte[] table5;
154
155        /**
156         * @param rng Generator of uniformly distributed random numbers.
157         * @param distributionName Distribution name.
158         * @param prob The probabilities.
159         * @param offset The offset (must be positive).
160         */
161        MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162                                                    String distributionName,
163                                                    int[] prob,
164                                                    int offset) {
165            super(rng, distributionName);
166
167            // Get table sizes for each base-64 digit
168            int n1 = 0;
169            int n2 = 0;
170            int n3 = 0;
171            int n4 = 0;
172            int n5 = 0;
173            for (final int m : prob) {
174                n1 += getBase64Digit(m, 1);
175                n2 += getBase64Digit(m, 2);
176                n3 += getBase64Digit(m, 3);
177                n4 += getBase64Digit(m, 4);
178                n5 += getBase64Digit(m, 5);
179            }
180
181            table1 = new byte[n1];
182            table2 = new byte[n2];
183            table3 = new byte[n3];
184            table4 = new byte[n4];
185            table5 = new byte[n5];
186
187            // Compute offsets
188            t1 = n1 << 24;
189            t2 = t1 + (n2 << 18);
190            t3 = t2 + (n3 << 12);
191            t4 = t3 + (n4 << 6);
192            n1 = n2 = n3 = n4 = n5 = 0;
193
194            // Fill tables
195            for (int i = 0; i < prob.length; i++) {
196                final int m = prob[i];
197                // Primitive type conversion will extract lower 8 bits
198                final byte k = (byte) (i + offset);
199                n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
200                n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
201                n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
202                n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
203                n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
204            }
205        }
206
207        /**
208         * @param rng Generator of uniformly distributed random numbers.
209         * @param source Source to copy.
210         */
211        private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
212                MarsagliaTsangWangBase64Int8DiscreteSampler source) {
213            super(rng, source);
214            t1 = source.t1;
215            t2 = source.t2;
216            t3 = source.t3;
217            t4 = source.t4;
218            table1 = source.table1;
219            table2 = source.table2;
220            table3 = source.table3;
221            table4 = source.table4;
222            table5 = source.table5;
223        }
224
225        /**
226         * Fill the table with the value.
227         *
228         * @param table Table.
229         * @param from Lower bound index (inclusive)
230         * @param to Upper bound index (exclusive)
231         * @param value Value.
232         * @return the upper bound index
233         */
234        private static int fill(byte[] table, int from, int to, byte value) {
235            for (int i = from; i < to; i++) {
236                table[i] = value;
237            }
238            return to;
239        }
240
241        @Override
242        public int sample() {
243            final int j = rng.nextInt() >>> 2;
244            if (j < t1) {
245                return table1[j >>> 24] & MASK;
246            }
247            if (j < t2) {
248                return table2[(j - t1) >>> 18] & MASK;
249            }
250            if (j < t3) {
251                return table3[(j - t2) >>> 12] & MASK;
252            }
253            if (j < t4) {
254                return table4[(j - t3) >>> 6] & MASK;
255            }
256            // Note the tables are filled on the assumption that the sum of the probabilities.
257            // is >=2^30. If this is not true then the final table table5 will be smaller by the
258            // difference. So the tables *must* be constructed correctly.
259            return table5[j - t4] & MASK;
260        }
261
262        @Override
263        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
264            return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
265        }
266    }
267
268    /**
269     * An implementation for the sample algorithm based on the decomposition of the
270     * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage.
271     */
272    private static class MarsagliaTsangWangBase64Int16DiscreteSampler
273        extends AbstractMarsagliaTsangWangDiscreteSampler {
274        /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */
275        private static final int MASK = 0xffff;
276
277        /** Limit for look-up table 1. */
278        private final int t1;
279        /** Limit for look-up table 2. */
280        private final int t2;
281        /** Limit for look-up table 3. */
282        private final int t3;
283        /** Limit for look-up table 4. */
284        private final int t4;
285
286        /** Look-up table table1. */
287        private final short[] table1;
288        /** Look-up table table2. */
289        private final short[] table2;
290        /** Look-up table table3. */
291        private final short[] table3;
292        /** Look-up table table4. */
293        private final short[] table4;
294        /** Look-up table table5. */
295        private final short[] table5;
296
297        /**
298         * @param rng Generator of uniformly distributed random numbers.
299         * @param distributionName Distribution name.
300         * @param prob The probabilities.
301         * @param offset The offset (must be positive).
302         */
303        MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
304                                                     String distributionName,
305                                                     int[] prob,
306                                                     int offset) {
307            super(rng, distributionName);
308
309            // Get table sizes for each base-64 digit
310            int n1 = 0;
311            int n2 = 0;
312            int n3 = 0;
313            int n4 = 0;
314            int n5 = 0;
315            for (final int m : prob) {
316                n1 += getBase64Digit(m, 1);
317                n2 += getBase64Digit(m, 2);
318                n3 += getBase64Digit(m, 3);
319                n4 += getBase64Digit(m, 4);
320                n5 += getBase64Digit(m, 5);
321            }
322
323            table1 = new short[n1];
324            table2 = new short[n2];
325            table3 = new short[n3];
326            table4 = new short[n4];
327            table5 = new short[n5];
328
329            // Compute offsets
330            t1 = n1 << 24;
331            t2 = t1 + (n2 << 18);
332            t3 = t2 + (n3 << 12);
333            t4 = t3 + (n4 << 6);
334            n1 = n2 = n3 = n4 = n5 = 0;
335
336            // Fill tables
337            for (int i = 0; i < prob.length; i++) {
338                final int m = prob[i];
339                // Primitive type conversion will extract lower 16 bits
340                final short k = (short) (i + offset);
341                n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
342                n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
343                n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
344                n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
345                n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
346            }
347        }
348
349        /**
350         * @param rng Generator of uniformly distributed random numbers.
351         * @param source Source to copy.
352         */
353        private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
354                MarsagliaTsangWangBase64Int16DiscreteSampler source) {
355            super(rng, source);
356            t1 = source.t1;
357            t2 = source.t2;
358            t3 = source.t3;
359            t4 = source.t4;
360            table1 = source.table1;
361            table2 = source.table2;
362            table3 = source.table3;
363            table4 = source.table4;
364            table5 = source.table5;
365        }
366
367        /**
368         * Fill the table with the value.
369         *
370         * @param table Table.
371         * @param from Lower bound index (inclusive)
372         * @param to Upper bound index (exclusive)
373         * @param value Value.
374         * @return the upper bound index
375         */
376        private static int fill(short[] table, int from, int to, short value) {
377            for (int i = from; i < to; i++) {
378                table[i] = value;
379            }
380            return to;
381        }
382
383        @Override
384        public int sample() {
385            final int j = rng.nextInt() >>> 2;
386            if (j < t1) {
387                return table1[j >>> 24] & MASK;
388            }
389            if (j < t2) {
390                return table2[(j - t1) >>> 18] & MASK;
391            }
392            if (j < t3) {
393                return table3[(j - t2) >>> 12] & MASK;
394            }
395            if (j < t4) {
396                return table4[(j - t3) >>> 6] & MASK;
397            }
398            // Note the tables are filled on the assumption that the sum of the probabilities.
399            // is >=2^30. If this is not true then the final table table5 will be smaller by the
400            // difference. So the tables *must* be constructed correctly.
401            return table5[j - t4] & MASK;
402        }
403
404        @Override
405        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
406            return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
407        }
408    }
409
410    /**
411     * An implementation for the sample algorithm based on the decomposition of the
412     * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage.
413     */
414    private static class MarsagliaTsangWangBase64Int32DiscreteSampler
415        extends AbstractMarsagliaTsangWangDiscreteSampler {
416        /** Limit for look-up table 1. */
417        private final int t1;
418        /** Limit for look-up table 2. */
419        private final int t2;
420        /** Limit for look-up table 3. */
421        private final int t3;
422        /** Limit for look-up table 4. */
423        private final int t4;
424
425        /** Look-up table table1. */
426        private final int[] table1;
427        /** Look-up table table2. */
428        private final int[] table2;
429        /** Look-up table table3. */
430        private final int[] table3;
431        /** Look-up table table4. */
432        private final int[] table4;
433        /** Look-up table table5. */
434        private final int[] table5;
435
436        /**
437         * @param rng Generator of uniformly distributed random numbers.
438         * @param distributionName Distribution name.
439         * @param prob The probabilities.
440         * @param offset The offset (must be positive).
441         */
442        MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
443                                                     String distributionName,
444                                                     int[] prob,
445                                                     int offset) {
446            super(rng, distributionName);
447
448            // Get table sizes for each base-64 digit
449            int n1 = 0;
450            int n2 = 0;
451            int n3 = 0;
452            int n4 = 0;
453            int n5 = 0;
454            for (final int m : prob) {
455                n1 += getBase64Digit(m, 1);
456                n2 += getBase64Digit(m, 2);
457                n3 += getBase64Digit(m, 3);
458                n4 += getBase64Digit(m, 4);
459                n5 += getBase64Digit(m, 5);
460            }
461
462            table1 = new int[n1];
463            table2 = new int[n2];
464            table3 = new int[n3];
465            table4 = new int[n4];
466            table5 = new int[n5];
467
468            // Compute offsets
469            t1 = n1 << 24;
470            t2 = t1 + (n2 << 18);
471            t3 = t2 + (n3 << 12);
472            t4 = t3 + (n4 << 6);
473            n1 = n2 = n3 = n4 = n5 = 0;
474
475            // Fill tables
476            for (int i = 0; i < prob.length; i++) {
477                final int m = prob[i];
478                final int k = i + offset;
479                n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
480                n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
481                n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
482                n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
483                n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
484            }
485        }
486
487        /**
488         * @param rng Generator of uniformly distributed random numbers.
489         * @param source Source to copy.
490         */
491        private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
492                MarsagliaTsangWangBase64Int32DiscreteSampler source) {
493            super(rng, source);
494            t1 = source.t1;
495            t2 = source.t2;
496            t3 = source.t3;
497            t4 = source.t4;
498            table1 = source.table1;
499            table2 = source.table2;
500            table3 = source.table3;
501            table4 = source.table4;
502            table5 = source.table5;
503        }
504
505        /**
506         * Fill the table with the value.
507         *
508         * @param table Table.
509         * @param from Lower bound index (inclusive)
510         * @param to Upper bound index (exclusive)
511         * @param value Value.
512         * @return the upper bound index
513         */
514        private static int fill(int[] table, int from, int to, int value) {
515            for (int i = from; i < to; i++) {
516                table[i] = value;
517            }
518            return to;
519        }
520
521        @Override
522        public int sample() {
523            final int j = rng.nextInt() >>> 2;
524            if (j < t1) {
525                return table1[j >>> 24];
526            }
527            if (j < t2) {
528                return table2[(j - t1) >>> 18];
529            }
530            if (j < t3) {
531                return table3[(j - t2) >>> 12];
532            }
533            if (j < t4) {
534                return table4[(j - t3) >>> 6];
535            }
536            // Note the tables are filled on the assumption that the sum of the probabilities.
537            // is >=2^30. If this is not true then the final table table5 will be smaller by the
538            // difference. So the tables *must* be constructed correctly.
539            return table5[j - t4];
540        }
541
542        @Override
543        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
544            return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
545        }
546    }
547
548
549
    /**
     * Class contains only static methods.
     *
     * <p>The constructor is private so that no instance can be created from outside
     * this class.</p>
     */
    private MarsagliaTsangWangDiscreteSampler() {}
552
553    /**
554     * Gets the k<sup>th</sup> base 64 digit of {@code m}.
555     *
556     * @param m the value m.
557     * @param k the digit.
558     * @return the base 64 digit
559     */
560    private static int getBase64Digit(int m, int k) {
561        return (m >>> (30 - 6 * k)) & 63;
562    }
563
564    /**
565     * Convert the probability to an integer in the range [0,2^30]. This is the numerator of
566     * a fraction with assumed denominator 2<sup>30</sup>.
567     *
568     * @param p Probability.
569     * @return the fraction numerator
570     */
571    private static int toUnsignedInt30(double p) {
572        return (int) (p * INT_30 + 0.5);
573    }
574
575    /**
576     * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
577     * {@code i + offset}.
578     *
579     * <p>The sum of the probabilities must be >= 2<sup>30</sup>. Only the
580     * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
581     *
582     * @param rng Generator of uniformly distributed random numbers.
583     * @param distributionName Distribution name.
584     * @param prob The probabilities.
585     * @param offset The offset (must be positive).
586     * @return Sampler.
587     */
588    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
589                                                            String distributionName,
590                                                            int[] prob,
591                                                            int offset) {
592        // Note: No argument checks for private method.
593
594        // Choose implementation based on the maximum index
595        final int maxIndex = prob.length + offset - 1;
596        if (maxIndex < INT_8) {
597            return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
598        }
599        if (maxIndex < INT_16) {
600            return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
601        }
602        return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
603    }
604
605    // =========================================================================
606    // The following public classes provide factory methods to construct a sampler for:
607    // - Enumerated probability distribution (from provided double[] probabilities)
608    // - Poisson distribution for mean <= 1024
609    // - Binomial distribution for trials <= 65535
610    // =========================================================================
611
612    /**
613     * Create a sampler for an enumerated distribution of {@code n} values each with an
614     * associated probability.
615     * The samples corresponding to each probability are assumed to be a natural sequence
616     * starting at zero.
617     */
618    public static final class Enumerated {
619        /** The name of the enumerated probability distribution. */
620        private static final String ENUMERATED_NAME = "Enumerated";
621
622        /** Class contains only static methods. */
623        private Enumerated() {}
624
625        /**
626         * Creates a sampler for an enumerated distribution of {@code n} values each with an
627         * associated probability.
628         *
629         * <p>The probabilities will be normalised using their sum. The only requirement
630         * is the sum is positive.</p>
631         *
632         * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that
633         * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during
634         * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup>
635         * will not be observed in samples.</p>
636         *
637         * @param rng Generator of uniformly distributed random numbers.
638         * @param probabilities The list of probabilities.
639         * @return Sampler.
640         * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
641         * probability is negative, infinite or {@code NaN}, or the sum of all
642         * probabilities is not strictly positive.
643         */
644        public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
645                                                    double[] probabilities) {
646            return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
647        }
648
649        /**
650         * Normalise the probabilities to integers that sum to 2<sup>30</sup>.
651         *
652         * @param probabilities The list of probabilities.
653         * @return the normalised probabilities.
654         * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
655         * probability is negative, infinite or {@code NaN}, or the sum of all
656         * probabilities is not strictly positive.
657         */
658        private static int[] normaliseProbabilities(double[] probabilities) {
659            final double sumProb = InternalUtils.validateProbabilities(probabilities);
660
661            // Compute the normalisation: 2^30 / sum
662            final double normalisation = INT_30 / sumProb;
663            final int[] prob = new int[probabilities.length];
664            int sum = 0;
665            int max = 0;
666            int mode = 0;
667            for (int i = 0; i < prob.length; i++) {
668                // Add 0.5 for rounding
669                final int p = (int) (probabilities[i] * normalisation + 0.5);
670                sum += p;
671                // Find the mode (maximum probability)
672                if (max < p) {
673                    max = p;
674                    mode = i;
675                }
676                prob[i] = p;
677            }
678
679            // The sum must be >= 2^30.
680            // Here just compensate the difference onto the highest probability.
681            prob[mode] += INT_30 - sum;
682
683            return prob;
684        }
685    }
686
687    /**
688     * Create a sampler for the Poisson distribution.
689     */
690    public static final class Poisson {
        /** The name of the Poisson distribution. */
        private static final String POISSON_NAME = "Poisson";

        /**
         * Upper bound on the mean for the Poisson distribution. A mean above this value
         * is rejected with an {@link IllegalArgumentException} during parameter validation.
         *
         * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
         * limit but the code fails at mean >= 1941 as the transform to compute p(x=mode)
         * produces infinity. Use a conservative limit of 1024.</p>
         */

        private static final double MAX_MEAN = 1024;
        /**
         * The threshold for the mean of the Poisson distribution to switch the method used
         * to compute the probabilities. This is taken from the example software provided by
         * Marsaglia, et al (2004).
         *
         * <p>A mean strictly below the threshold has its probabilities computed from
         * {@code X=0}; otherwise they are computed from {@code X=mode}.</p>
         */
        private static final double MEAN_THRESHOLD = 21.4;

        /** Class contains only static methods. */
        private Poisson() {}
712
713        /**
714         * Creates a sampler for the Poisson distribution.
715         *
716         * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
717         *
718         * <p>Storage requirements depend on the tabulated probability values. Example storage
719         * requirements are listed below.</p>
720         *
721         * <pre>
722         * mean      table size     kB
723         * 0.25      882            0.88
724         * 0.5       1135           1.14
725         * 1         1200           1.20
726         * 2         1451           1.45
727         * 4         1955           1.96
728         * 8         2961           2.96
729         * 16        4410           4.41
730         * 32        6115           6.11
731         * 64        8499           8.50
732         * 128       11528          11.53
733         * 256       15935          31.87
734         * 512       20912          41.82
735         * 1024      30614          61.23
736         * </pre>
737         *
738         * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p>
739         *
740         * @param rng Generator of uniformly distributed random numbers.
741         * @param mean Mean.
742         * @return Sampler.
743         * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
744         */
745        public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
746                                                    double mean) {
747            validatePoissonDistributionParameters(mean);
748
749            // Create the distribution either from X=0 or from X=mode when the mean is high.
750            return mean < MEAN_THRESHOLD ?
751                createPoissonDistributionFromX0(rng, mean) :
752                createPoissonDistributionFromXMode(rng, mean);
753        }
754
755        /**
756         * Validate the Poisson distribution parameters.
757         *
758         * @param mean Mean.
759         * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
760         */
761        private static void validatePoissonDistributionParameters(double mean) {
762            if (mean <= 0) {
763                throw new IllegalArgumentException("mean is not strictly positive: " + mean);
764            }
765            if (mean > MAX_MEAN) {
766                throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
767            }
768        }
769
770        /**
771         * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}.
772         *
773         * @param rng Generator of uniformly distributed random numbers.
774         * @param mean Mean.
775         * @return Sampler.
776         */
777        private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
778                UniformRandomProvider rng, double mean) {
779            final double p0 = Math.exp(-mean);
780
781            // Recursive update of Poisson probability until the value is too small
782            // p(x + 1) = p(x) * mean / (x + 1)
783            double p = p0;
784            int i = 1;
785            while (p * DOUBLE_31 >= 1) {
786                p *= mean / i++;
787            }
788
789            // Probabilities are 30-bit integers, assumed denominator 2^30
790            final int size = i - 1;
791            final int[] prob = new int[size];
792
793            p = p0;
794            prob[0] = toUnsignedInt30(p);
795            // The sum must exceed 2^30. In edges cases this is false due to round-off.
796            int sum = prob[0];
797            for (i = 1; i < prob.length; i++) {
798                p *= mean / i;
799                prob[i] = toUnsignedInt30(p);
800                sum += prob[i];
801            }
802
803            // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
804            prob[(int) mean] += Math.max(0, INT_30 - sum);
805
806            // Note: offset = 0
807            return createSampler(rng, POISSON_NAME, prob, 0);
808        }
809
810        /**
811         * Creates the Poisson distribution by computing probabilities recursively upward and downward
812         * from {@code X=mode}, the location of the largest p-value.
813         *
814         * @param rng Generator of uniformly distributed random numbers.
815         * @param mean Mean.
816         * @return Sampler.
817         */
818        private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
819                UniformRandomProvider rng, double mean) {
820            // If mean >= 21.4, generate from largest p-value up, then largest down.
821            // The largest p-value will be at the mode (floor(mean)).
822
823            // Find p(x=mode)
824            final int mode = (int) mean;
825            // This transform is stable until mean >= 1941 where p will result in Infinity
826            // before the divisor i is large enough to start reducing the product (i.e. i > c).
827            final double c = mean * Math.exp(-mean / mode);
828            double p = 1.0;
829            for (int i = 1; i <= mode; i++) {
830                p *= c / i;
831            }
832            final double pMode = p;
833
834            // Find the upper limit using recursive computation of the p-value.
835            // Note this will exit when i overflows to negative so no check on the range
836            int i = mode + 1;
837            while (p * DOUBLE_31 >= 1) {
838                p *= mean / i++;
839            }
840            final int last = i - 2;
841
842            // Find the lower limit using recursive computation of the p-value.
843            p = pMode;
844            int j = -1;
845            for (i = mode - 1; i >= 0; i--) {
846                p *= (i + 1) / mean;
847                if (p * DOUBLE_31 < 1) {
848                    j = i;
849                    break;
850                }
851            }
852
853            // Probabilities are 30-bit integers, assumed denominator 2^30.
854            // This is the minimum sample value: prob[x - offset] = p(x)
855            final int offset = j + 1;
856            final int size = last - offset + 1;
857            final int[] prob = new int[size];
858
859            p = pMode;
860            prob[mode - offset] = toUnsignedInt30(p);
861            // The sum must exceed 2^30. In edges cases this is false due to round-off.
862            int sum = prob[mode - offset];
863            // From mode to upper limit
864            for (i = mode + 1; i <= last; i++) {
865                p *= mean / i;
866                prob[i - offset] = toUnsignedInt30(p);
867                sum += prob[i - offset];
868            }
869            // From mode to lower limit
870            p = pMode;
871            for (i = mode - 1; i >= offset; i--) {
872                p *= (i + 1) / mean;
873                prob[i - offset] = toUnsignedInt30(p);
874                sum += prob[i - offset];
875            }
876
877            // If the sum is < 2^30 add the remaining sum to the mode.
878            // If above 2^30 then the effect is truncation of the long tail of the distribution.
879            prob[mode - offset] += Math.max(0, INT_30 - sum);
880
881            return createSampler(rng, POISSON_NAME, prob, offset);
882        }
883    }
884
885    /**
886     * Create a sampler for the Binomial distribution.
887     */
888    public static final class Binomial {
889        /** The name of the Binomial distribution. */
890        private static final String BINOMIAL_NAME = "Binomial";
891
892        /**
893         * Return a fixed result for the Binomial distribution. This is a special class to handle
894         * an edge case of probability of success equal to 0 or 1.
895         */
896        private static class MarsagliaTsangWangFixedResultBinomialSampler
897            extends AbstractMarsagliaTsangWangDiscreteSampler {
898            /** The result. */
899            private final int result;
900
901            /**
902             * @param result Result.
903             */
904            MarsagliaTsangWangFixedResultBinomialSampler(int result) {
905                super(null, BINOMIAL_NAME);
906                this.result = result;
907            }
908
909            @Override
910            public int sample() {
911                return result;
912            }
913
914            @Override
915            public String toString() {
916                return BINOMIAL_NAME + " deviate";
917            }
918
919            @Override
920            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
921                // No shared state
922                return this;
923            }
924        }
925
926        /**
927         * Return an inversion result for the Binomial distribution. This assumes the
928         * following:
929         *
930         * <pre>
931         * Binomial(n, p) = 1 - Binomial(n, 1 - p)
932         * </pre>
933         */
934        private static class MarsagliaTsangWangInversionBinomialSampler
935            extends AbstractMarsagliaTsangWangDiscreteSampler {
936            /** The number of trials. */
937            private final int trials;
938            /** The Binomial distribution sampler. */
939            private final SharedStateDiscreteSampler sampler;
940
941            /**
942             * @param trials Number of trials.
943             * @param sampler Binomial distribution sampler.
944             */
945            MarsagliaTsangWangInversionBinomialSampler(int trials,
946                                                       SharedStateDiscreteSampler sampler) {
947                super(null, BINOMIAL_NAME);
948                this.trials = trials;
949                this.sampler = sampler;
950            }
951
952            @Override
953            public int sample() {
954                return trials - sampler.sample();
955            }
956
957            @Override
958            public String toString() {
959                return sampler.toString();
960            }
961
962            @Override
963            public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
964                return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
965                    this.sampler.withUniformRandomProvider(rng));
966            }
967        }
968
969        /** Class contains only static methods. */
970        private Binomial() {}
971
972        /**
973         * Creates a sampler for the Binomial distribution.
974         *
975         * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
976         *
977         * <p>Storage requirements depend on the tabulated probability values. Example storage
978         * requirements are listed below (in kB).</p>
979         *
980         * <pre>
981         *          p
982         * trials   0.5    0.1   0.01  0.001
983         *    4    0.06   0.63   0.44   0.44
984         *   16    0.69   1.14   0.76   0.44
985         *   64    4.73   2.40   1.14   0.51
986         *  256    8.63   5.17   1.89   0.82
987         * 1024   31.12   9.45   3.34   0.89
988         * </pre>
989         *
990         * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed.
991         * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large
992         * and/or {@code p} to be small. In this case an exception is raised.</p>
993         *
994         * @param rng Generator of uniformly distributed random numbers.
995         * @param trials Number of trials.
996         * @param probabilityOfSuccess Probability of success (p).
997         * @return Sampler.
998         * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
999         * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot
1000         * be computed.
1001         */
1002        public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1003                                                    int trials,
1004                                                    double probabilityOfSuccess) {
1005            validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1006
1007            // Handle edge cases
1008            if (probabilityOfSuccess == 0) {
1009                return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1010            }
1011            if (probabilityOfSuccess == 1) {
1012                return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1013            }
1014
1015            // Check the supported size.
1016            if (trials >= INT_16) {
1017                throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1018            }
1019
1020            return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1021        }
1022
1023        /**
1024         * Validate the Binomial distribution parameters.
1025         *
1026         * @param trials Number of trials.
1027         * @param probabilityOfSuccess Probability of success (p).
1028         * @throws IllegalArgumentException if {@code trials < 0} or
1029         * {@code p} is not in the range {@code [0-1]}
1030         */
1031        private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1032            if (trials < 0) {
1033                throw new IllegalArgumentException("Trials is not positive: " + trials);
1034            }
1035            if (probabilityOfSuccess < 0 || probabilityOfSuccess > 1) {
1036                throw new IllegalArgumentException("Probability is not in range [0,1]: " + probabilityOfSuccess);
1037            }
1038        }
1039
1040        /**
1041         * Creates the Binomial distribution sampler.
1042         *
1043         * <p>This assumes the parameters for the distribution are valid. The method
1044         * will only fail if the initial probability for {@code X=0} is zero.</p>
1045         *
1046         * @param rng Generator of uniformly distributed random numbers.
1047         * @param trials Number of trials.
1048         * @param probabilityOfSuccess Probability of success (p).
1049         * @return Sampler.
1050         * @throws IllegalArgumentException if the probability distribution cannot be
1051         * computed.
1052         */
1053        private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1054                UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1055
1056            // The maximum supported value for Math.exp is approximately -744.
1057            // This occurs when trials is large and p is close to 1.
1058            // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
1059            final boolean useInversion = probabilityOfSuccess > 0.5;
1060            final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1061
1062            // Check if the distribution can be computed
1063            final double p0 = Math.exp(trials * Math.log(1 - p));
1064            if (p0 < Double.MIN_VALUE) {
1065                throw new IllegalArgumentException("Unable to compute distribution");
1066            }
1067
1068            // First find size of probability array
1069            double t = p0;
1070            final double h = p / (1 - p);
1071            // Find first probability above the threshold of 2^-31
1072            int begin = 0;
1073            if (t * DOUBLE_31 < 1) {
1074                // Somewhere after p(0)
1075                // Note:
1076                // If this loop is entered p(0) is < 2^-31.
1077                // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
1078                // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
1079                for (int i = 1; i <= trials; i++) {
1080                    t *= (trials + 1 - i) * h / i;
1081                    if (t * DOUBLE_31 >= 1) {
1082                        begin = i;
1083                        break;
1084                    }
1085                }
1086            }
1087            // Find last probability
1088            int end = trials;
1089            for (int i = begin + 1; i <= trials; i++) {
1090                t *= (trials + 1 - i) * h / i;
1091                if (t * DOUBLE_31 < 1) {
1092                    end = i - 1;
1093                    break;
1094                }
1095            }
1096
1097            return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1098                    p0, begin, end);
1099        }
1100
1101        /**
1102         * Creates the Binomial distribution sampler using only the probability values for {@code X}
1103         * between the begin and the end (inclusive).
1104         *
1105         * @param rng Generator of uniformly distributed random numbers.
1106         * @param trials Number of trials.
1107         * @param p Probability of success (p).
1108         * @param useInversion Set to {@code true} if the probability was inverted.
1109         * @param p0 Probability at {@code X=0}
1110         * @param begin Begin value {@code X} for the distribution.
1111         * @param end End value {@code X} for the distribution.
1112         * @return Sampler.
1113         */
1114        private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1115                UniformRandomProvider rng, int trials, double p,
1116                boolean useInversion, double p0, int begin, int end) {
1117
1118            // Assign probability values as 30-bit integers
1119            final int size = end - begin + 1;
1120            final int[] prob = new int[size];
1121            double t = p0;
1122            final double h = p / (1 - p);
1123            for (int i = 1; i <= begin; i++) {
1124                t *= (trials + 1 - i) * h / i;
1125            }
1126            int sum = toUnsignedInt30(t);
1127            prob[0] = sum;
1128            for (int i = begin + 1; i <= end; i++) {
1129                t *= (trials + 1 - i) * h / i;
1130                prob[i - begin] = toUnsignedInt30(t);
1131                sum += prob[i - begin];
1132            }
1133
1134            // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
1135            // If above 2^30 then the effect is truncation of the long tail of the distribution.
1136            final int mode = (int) ((trials + 1) * p) - begin;
1137            prob[mode] += Math.max(0, INT_30 - sum);
1138
1139            final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1140
1141            // Check if an inversion was made
1142            return useInversion ?
1143                   new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1144                   sampler;
1145        }
1146    }
1147}