View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  
21  /**
22   * Sampler for a discrete distribution using an optimised look-up table.
23   *
24   * <ul>
25   *  <li>
26   *   The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
27   *   in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
28   *   Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
29   *  </li>
30   * </ul>
31   *
32   * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
33   *
34   * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
35   * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
36   * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
37   * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
38   * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
39   *
40   * <p>The sampler supports the following distributions:</p>
41   *
42   * <ul>
43   *  <li>Enumerated distribution (probabilities must be provided for each sample)
44   *  <li>Poisson distribution up to {@code mean = 1024}
45   *  <li>Binomial distribution up to {@code trials = 65535}
46   * </ul>
47   *
 * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
 * 11, Issue 3</a>
50   * @since 1.3
51   */
52  public final class MarsagliaTsangWangDiscreteSampler {
53      /** The value 2<sup>8</sup> as an {@code int}. */
54      private static final int INT_8 = 1 << 8;
55      /** The value 2<sup>16</sup> as an {@code int}. */
56      private static final int INT_16 = 1 << 16;
57      /** The value 2<sup>30</sup> as an {@code int}. */
58      private static final int INT_30 = 1 << 30;
59      /** The value 2<sup>31</sup> as a {@code double}. */
60      private static final double DOUBLE_31 = 1L << 31;
61  
62      // =========================================================================
63      // Implementation note:
64      //
65      // This sampler uses prepared look-up tables that are searched using a single
66      // random int variate. The look-up tables contain the sample value. The tables
67      // are constructed using probabilities that sum to 2^30. The original paper
68      // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables
69      // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6)
70      // is supported using 5 look-up tables.
71      //
72      // The implementations use 8, 16 or 32 bit storage tables to support different
73      // distribution sizes with optimal storage. Separate class implementations of
74      // the same algorithm allow array storage to be accessed directly from 1D tables.
75      // This provides a performance gain over using: abstracted storage accessed via
76      // an interface; or a single 2D table.
77      //
78      // To allow the optimal implementation to be chosen the sampler is created
79      // using factory methods. The sampler supports any probability distribution
80      // when provided via an array of probabilities and the Poisson and Binomial
81      // distributions for a restricted set of parameters. The restrictions are
82      // imposed by the requirement to compute the entire probability distribution
83      // from the controlling parameter(s) using a recursive method. Factory
84      // constructors return a SharedStateDiscreteSampler instance. Each distribution
85      // type is contained in an inner class.
86      // =========================================================================
87  
88      /**
89       * The base class for Marsaglia-Tsang-Wang samplers.
90       */
91      private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
92              implements SharedStateDiscreteSampler {
93          /** Underlying source of randomness. */
94          protected final UniformRandomProvider rng;
95  
96          /** The name of the distribution. */
97          private final String distributionName;
98  
99          /**
100          * @param rng Generator of uniformly distributed random numbers.
101          * @param distributionName Distribution name.
102          */
103         AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104                                                   String distributionName) {
105             this.rng = rng;
106             this.distributionName = distributionName;
107         }
108 
109         /**
110          * @param rng Generator of uniformly distributed random numbers.
111          * @param source Source to copy.
112          */
113         AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114                                                   AbstractMarsagliaTsangWangDiscreteSampler source) {
115             this.rng = rng;
116             this.distributionName = source.distributionName;
117         }
118 
119         /** {@inheritDoc} */
120         @Override
121         public String toString() {
122             return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123         }
124     }
125 
126     /**
127      * An implementation for the sample algorithm based on the decomposition of the
128      * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage.
129      */
130     private static class MarsagliaTsangWangBase64Int8DiscreteSampler
131         extends AbstractMarsagliaTsangWangDiscreteSampler {
132         /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */
133         private static final int MASK = 0xff;
134 
135         /** Limit for look-up table 1. */
136         private final int t1;
137         /** Limit for look-up table 2. */
138         private final int t2;
139         /** Limit for look-up table 3. */
140         private final int t3;
141         /** Limit for look-up table 4. */
142         private final int t4;
143 
144         /** Look-up table table1. */
145         private final byte[] table1;
146         /** Look-up table table2. */
147         private final byte[] table2;
148         /** Look-up table table3. */
149         private final byte[] table3;
150         /** Look-up table table4. */
151         private final byte[] table4;
152         /** Look-up table table5. */
153         private final byte[] table5;
154 
155         /**
156          * @param rng Generator of uniformly distributed random numbers.
157          * @param distributionName Distribution name.
158          * @param prob The probabilities.
159          * @param offset The offset (must be positive).
160          */
161         MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162                                                     String distributionName,
163                                                     int[] prob,
164                                                     int offset) {
165             super(rng, distributionName);
166 
167             // Get table sizes for each base-64 digit
168             int n1 = 0;
169             int n2 = 0;
170             int n3 = 0;
171             int n4 = 0;
172             int n5 = 0;
173             for (final int m : prob) {
174                 n1 += getBase64Digit(m, 1);
175                 n2 += getBase64Digit(m, 2);
176                 n3 += getBase64Digit(m, 3);
177                 n4 += getBase64Digit(m, 4);
178                 n5 += getBase64Digit(m, 5);
179             }
180 
181             table1 = new byte[n1];
182             table2 = new byte[n2];
183             table3 = new byte[n3];
184             table4 = new byte[n4];
185             table5 = new byte[n5];
186 
187             // Compute offsets
188             t1 = n1 << 24;
189             t2 = t1 + (n2 << 18);
190             t3 = t2 + (n3 << 12);
191             t4 = t3 + (n4 << 6);
192             n1 = n2 = n3 = n4 = n5 = 0;
193 
194             // Fill tables
195             for (int i = 0; i < prob.length; i++) {
196                 final int m = prob[i];
197                 // Primitive type conversion will extract lower 8 bits
198                 final byte k = (byte) (i + offset);
199                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
200                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
201                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
202                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
203                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
204             }
205         }
206 
207         /**
208          * @param rng Generator of uniformly distributed random numbers.
209          * @param source Source to copy.
210          */
211         private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
212                 MarsagliaTsangWangBase64Int8DiscreteSampler source) {
213             super(rng, source);
214             t1 = source.t1;
215             t2 = source.t2;
216             t3 = source.t3;
217             t4 = source.t4;
218             table1 = source.table1;
219             table2 = source.table2;
220             table3 = source.table3;
221             table4 = source.table4;
222             table5 = source.table5;
223         }
224 
225         /**
226          * Fill the table with the value.
227          *
228          * @param table Table.
229          * @param from Lower bound index (inclusive)
230          * @param to Upper bound index (exclusive)
231          * @param value Value.
232          * @return the upper bound index
233          */
234         private static int fill(byte[] table, int from, int to, byte value) {
235             for (int i = from; i < to; i++) {
236                 table[i] = value;
237             }
238             return to;
239         }
240 
241         @Override
242         public int sample() {
243             final int j = rng.nextInt() >>> 2;
244             if (j < t1) {
245                 return table1[j >>> 24] & MASK;
246             }
247             if (j < t2) {
248                 return table2[(j - t1) >>> 18] & MASK;
249             }
250             if (j < t3) {
251                 return table3[(j - t2) >>> 12] & MASK;
252             }
253             if (j < t4) {
254                 return table4[(j - t3) >>> 6] & MASK;
255             }
256             // Note the tables are filled on the assumption that the sum of the probabilities.
257             // is >=2^30. If this is not true then the final table table5 will be smaller by the
258             // difference. So the tables *must* be constructed correctly.
259             return table5[j - t4] & MASK;
260         }
261 
262         @Override
263         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
264             return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
265         }
266     }
267 
268     /**
269      * An implementation for the sample algorithm based on the decomposition of the
270      * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage.
271      */
272     private static class MarsagliaTsangWangBase64Int16DiscreteSampler
273         extends AbstractMarsagliaTsangWangDiscreteSampler {
274         /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */
275         private static final int MASK = 0xffff;
276 
277         /** Limit for look-up table 1. */
278         private final int t1;
279         /** Limit for look-up table 2. */
280         private final int t2;
281         /** Limit for look-up table 3. */
282         private final int t3;
283         /** Limit for look-up table 4. */
284         private final int t4;
285 
286         /** Look-up table table1. */
287         private final short[] table1;
288         /** Look-up table table2. */
289         private final short[] table2;
290         /** Look-up table table3. */
291         private final short[] table3;
292         /** Look-up table table4. */
293         private final short[] table4;
294         /** Look-up table table5. */
295         private final short[] table5;
296 
297         /**
298          * @param rng Generator of uniformly distributed random numbers.
299          * @param distributionName Distribution name.
300          * @param prob The probabilities.
301          * @param offset The offset (must be positive).
302          */
303         MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
304                                                      String distributionName,
305                                                      int[] prob,
306                                                      int offset) {
307             super(rng, distributionName);
308 
309             // Get table sizes for each base-64 digit
310             int n1 = 0;
311             int n2 = 0;
312             int n3 = 0;
313             int n4 = 0;
314             int n5 = 0;
315             for (final int m : prob) {
316                 n1 += getBase64Digit(m, 1);
317                 n2 += getBase64Digit(m, 2);
318                 n3 += getBase64Digit(m, 3);
319                 n4 += getBase64Digit(m, 4);
320                 n5 += getBase64Digit(m, 5);
321             }
322 
323             table1 = new short[n1];
324             table2 = new short[n2];
325             table3 = new short[n3];
326             table4 = new short[n4];
327             table5 = new short[n5];
328 
329             // Compute offsets
330             t1 = n1 << 24;
331             t2 = t1 + (n2 << 18);
332             t3 = t2 + (n3 << 12);
333             t4 = t3 + (n4 << 6);
334             n1 = n2 = n3 = n4 = n5 = 0;
335 
336             // Fill tables
337             for (int i = 0; i < prob.length; i++) {
338                 final int m = prob[i];
339                 // Primitive type conversion will extract lower 16 bits
340                 final short k = (short) (i + offset);
341                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
342                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
343                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
344                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
345                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
346             }
347         }
348 
349         /**
350          * @param rng Generator of uniformly distributed random numbers.
351          * @param source Source to copy.
352          */
353         private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
354                 MarsagliaTsangWangBase64Int16DiscreteSampler source) {
355             super(rng, source);
356             t1 = source.t1;
357             t2 = source.t2;
358             t3 = source.t3;
359             t4 = source.t4;
360             table1 = source.table1;
361             table2 = source.table2;
362             table3 = source.table3;
363             table4 = source.table4;
364             table5 = source.table5;
365         }
366 
367         /**
368          * Fill the table with the value.
369          *
370          * @param table Table.
371          * @param from Lower bound index (inclusive)
372          * @param to Upper bound index (exclusive)
373          * @param value Value.
374          * @return the upper bound index
375          */
376         private static int fill(short[] table, int from, int to, short value) {
377             for (int i = from; i < to; i++) {
378                 table[i] = value;
379             }
380             return to;
381         }
382 
383         @Override
384         public int sample() {
385             final int j = rng.nextInt() >>> 2;
386             if (j < t1) {
387                 return table1[j >>> 24] & MASK;
388             }
389             if (j < t2) {
390                 return table2[(j - t1) >>> 18] & MASK;
391             }
392             if (j < t3) {
393                 return table3[(j - t2) >>> 12] & MASK;
394             }
395             if (j < t4) {
396                 return table4[(j - t3) >>> 6] & MASK;
397             }
398             // Note the tables are filled on the assumption that the sum of the probabilities.
399             // is >=2^30. If this is not true then the final table table5 will be smaller by the
400             // difference. So the tables *must* be constructed correctly.
401             return table5[j - t4] & MASK;
402         }
403 
404         @Override
405         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
406             return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
407         }
408     }
409 
410     /**
411      * An implementation for the sample algorithm based on the decomposition of the
412      * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage.
413      */
414     private static class MarsagliaTsangWangBase64Int32DiscreteSampler
415         extends AbstractMarsagliaTsangWangDiscreteSampler {
416         /** Limit for look-up table 1. */
417         private final int t1;
418         /** Limit for look-up table 2. */
419         private final int t2;
420         /** Limit for look-up table 3. */
421         private final int t3;
422         /** Limit for look-up table 4. */
423         private final int t4;
424 
425         /** Look-up table table1. */
426         private final int[] table1;
427         /** Look-up table table2. */
428         private final int[] table2;
429         /** Look-up table table3. */
430         private final int[] table3;
431         /** Look-up table table4. */
432         private final int[] table4;
433         /** Look-up table table5. */
434         private final int[] table5;
435 
436         /**
437          * @param rng Generator of uniformly distributed random numbers.
438          * @param distributionName Distribution name.
439          * @param prob The probabilities.
440          * @param offset The offset (must be positive).
441          */
442         MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
443                                                      String distributionName,
444                                                      int[] prob,
445                                                      int offset) {
446             super(rng, distributionName);
447 
448             // Get table sizes for each base-64 digit
449             int n1 = 0;
450             int n2 = 0;
451             int n3 = 0;
452             int n4 = 0;
453             int n5 = 0;
454             for (final int m : prob) {
455                 n1 += getBase64Digit(m, 1);
456                 n2 += getBase64Digit(m, 2);
457                 n3 += getBase64Digit(m, 3);
458                 n4 += getBase64Digit(m, 4);
459                 n5 += getBase64Digit(m, 5);
460             }
461 
462             table1 = new int[n1];
463             table2 = new int[n2];
464             table3 = new int[n3];
465             table4 = new int[n4];
466             table5 = new int[n5];
467 
468             // Compute offsets
469             t1 = n1 << 24;
470             t2 = t1 + (n2 << 18);
471             t3 = t2 + (n3 << 12);
472             t4 = t3 + (n4 << 6);
473             n1 = n2 = n3 = n4 = n5 = 0;
474 
475             // Fill tables
476             for (int i = 0; i < prob.length; i++) {
477                 final int m = prob[i];
478                 final int k = i + offset;
479                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
480                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
481                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
482                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
483                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
484             }
485         }
486 
487         /**
488          * @param rng Generator of uniformly distributed random numbers.
489          * @param source Source to copy.
490          */
491         private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
492                 MarsagliaTsangWangBase64Int32DiscreteSampler source) {
493             super(rng, source);
494             t1 = source.t1;
495             t2 = source.t2;
496             t3 = source.t3;
497             t4 = source.t4;
498             table1 = source.table1;
499             table2 = source.table2;
500             table3 = source.table3;
501             table4 = source.table4;
502             table5 = source.table5;
503         }
504 
505         /**
506          * Fill the table with the value.
507          *
508          * @param table Table.
509          * @param from Lower bound index (inclusive)
510          * @param to Upper bound index (exclusive)
511          * @param value Value.
512          * @return the upper bound index
513          */
514         private static int fill(int[] table, int from, int to, int value) {
515             for (int i = from; i < to; i++) {
516                 table[i] = value;
517             }
518             return to;
519         }
520 
521         @Override
522         public int sample() {
523             final int j = rng.nextInt() >>> 2;
524             if (j < t1) {
525                 return table1[j >>> 24];
526             }
527             if (j < t2) {
528                 return table2[(j - t1) >>> 18];
529             }
530             if (j < t3) {
531                 return table3[(j - t2) >>> 12];
532             }
533             if (j < t4) {
534                 return table4[(j - t3) >>> 6];
535             }
536             // Note the tables are filled on the assumption that the sum of the probabilities.
537             // is >=2^30. If this is not true then the final table table5 will be smaller by the
538             // difference. So the tables *must* be constructed correctly.
539             return table5[j - t4];
540         }
541 
542         @Override
543         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
544             return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
545         }
546     }
547 
548 
549 
550     /** Class contains only static methods. */
551     private MarsagliaTsangWangDiscreteSampler() {}
552 
553     /**
554      * Gets the k<sup>th</sup> base 64 digit of {@code m}.
555      *
556      * @param m the value m.
557      * @param k the digit.
558      * @return the base 64 digit
559      */
560     private static int getBase64Digit(int m, int k) {
561         return (m >>> (30 - 6 * k)) & 63;
562     }
563 
564     /**
565      * Convert the probability to an integer in the range [0,2^30]. This is the numerator of
566      * a fraction with assumed denominator 2<sup>30</sup>.
567      *
568      * @param p Probability.
569      * @return the fraction numerator
570      */
571     private static int toUnsignedInt30(double p) {
572         return (int) (p * INT_30 + 0.5);
573     }
574 
575     /**
576      * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
577      * {@code i + offset}.
578      *
579      * <p>The sum of the probabilities must be >= 2<sup>30</sup>. Only the
580      * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
581      *
582      * @param rng Generator of uniformly distributed random numbers.
583      * @param distributionName Distribution name.
584      * @param prob The probabilities.
585      * @param offset The offset (must be positive).
586      * @return Sampler.
587      */
588     private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
589                                                             String distributionName,
590                                                             int[] prob,
591                                                             int offset) {
592         // Note: No argument checks for private method.
593 
594         // Choose implementation based on the maximum index
595         final int maxIndex = prob.length + offset - 1;
596         if (maxIndex < INT_8) {
597             return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
598         }
599         if (maxIndex < INT_16) {
600             return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
601         }
602         return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
603     }
604 
605     // =========================================================================
606     // The following public classes provide factory methods to construct a sampler for:
607     // - Enumerated probability distribution (from provided double[] probabilities)
608     // - Poisson distribution for mean <= 1024
609     // - Binomial distribution for trials <= 65535
610     // =========================================================================
611 
612     /**
613      * Create a sampler for an enumerated distribution of {@code n} values each with an
614      * associated probability.
615      * The samples corresponding to each probability are assumed to be a natural sequence
616      * starting at zero.
617      */
618     public static final class Enumerated {
619         /** The name of the enumerated probability distribution. */
620         private static final String ENUMERATED_NAME = "Enumerated";
621 
622         /** Class contains only static methods. */
623         private Enumerated() {}
624 
625         /**
626          * Creates a sampler for an enumerated distribution of {@code n} values each with an
627          * associated probability.
628          *
629          * <p>The probabilities will be normalised using their sum. The only requirement
630          * is the sum is positive.</p>
631          *
632          * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that
633          * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during
634          * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup>
635          * will not be observed in samples.</p>
636          *
637          * @param rng Generator of uniformly distributed random numbers.
638          * @param probabilities The list of probabilities.
639          * @return Sampler.
640          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
641          * probability is negative, infinite or {@code NaN}, or the sum of all
642          * probabilities is not strictly positive.
643          */
644         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
645                                                     double[] probabilities) {
646             return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
647         }
648 
649         /**
650          * Normalise the probabilities to integers that sum to 2<sup>30</sup>.
651          *
652          * @param probabilities The list of probabilities.
653          * @return the normalised probabilities.
654          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
655          * probability is negative, infinite or {@code NaN}, or the sum of all
656          * probabilities is not strictly positive.
657          */
658         private static int[] normaliseProbabilities(double[] probabilities) {
659             final double sumProb = InternalUtils.validateProbabilities(probabilities);
660 
661             // Compute the normalisation: 2^30 / sum
662             final double normalisation = INT_30 / sumProb;
663             final int[] prob = new int[probabilities.length];
664             int sum = 0;
665             int max = 0;
666             int mode = 0;
667             for (int i = 0; i < prob.length; i++) {
668                 // Add 0.5 for rounding
669                 final int p = (int) (probabilities[i] * normalisation + 0.5);
670                 sum += p;
671                 // Find the mode (maximum probability)
672                 if (max < p) {
673                     max = p;
674                     mode = i;
675                 }
676                 prob[i] = p;
677             }
678 
679             // The sum must be >= 2^30.
680             // Here just compensate the difference onto the highest probability.
681             prob[mode] += INT_30 - sum;
682 
683             return prob;
684         }
685     }
686 
687     /**
688      * Create a sampler for the Poisson distribution.
689      */
690     public static final class Poisson {
691         /** The name of the Poisson distribution. */
692         private static final String POISSON_NAME = "Poisson";
693 
694         /**
695          * Upper bound on the mean for the Poisson distribution.
696          *
697          * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
698          * limit but the code fails at mean >= 1941 as the transform to compute p(x=mode)
699          * produces infinity. Use a conservative limit of 1024.</p>
700          */
701 
702         private static final double MAX_MEAN = 1024;
703         /**
704          * The threshold for the mean of the Poisson distribution to switch the method used
705          * to compute the probabilities. This is taken from the example software provided by
706          * Marsaglia, et al (2004).
707          */
708         private static final double MEAN_THRESHOLD = 21.4;
709 
710         /** Class contains only static methods. */
711         private Poisson() {}
712 
713         /**
714          * Creates a sampler for the Poisson distribution.
715          *
716          * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
717          *
718          * <p>Storage requirements depend on the tabulated probability values. Example storage
719          * requirements are listed below.</p>
720          *
721          * <pre>
722          * mean      table size     kB
723          * 0.25      882            0.88
724          * 0.5       1135           1.14
725          * 1         1200           1.20
726          * 2         1451           1.45
727          * 4         1955           1.96
728          * 8         2961           2.96
729          * 16        4410           4.41
730          * 32        6115           6.11
731          * 64        8499           8.50
732          * 128       11528          11.53
733          * 256       15935          31.87
734          * 512       20912          41.82
735          * 1024      30614          61.23
736          * </pre>
737          *
738          * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p>
739          *
740          * @param rng Generator of uniformly distributed random numbers.
741          * @param mean Mean.
742          * @return Sampler.
743          * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
744          */
745         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
746                                                     double mean) {
747             validatePoissonDistributionParameters(mean);
748 
749             // Create the distribution either from X=0 or from X=mode when the mean is high.
750             return mean < MEAN_THRESHOLD ?
751                 createPoissonDistributionFromX0(rng, mean) :
752                 createPoissonDistributionFromXMode(rng, mean);
753         }
754 
755         /**
756          * Validate the Poisson distribution parameters.
757          *
758          * @param mean Mean.
759          * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
760          */
761         private static void validatePoissonDistributionParameters(double mean) {
762             if (mean <= 0) {
763                 throw new IllegalArgumentException("mean is not strictly positive: " + mean);
764             }
765             if (mean > MAX_MEAN) {
766                 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
767             }
768         }
769 
770         /**
771          * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}.
772          *
773          * @param rng Generator of uniformly distributed random numbers.
774          * @param mean Mean.
775          * @return Sampler.
776          */
777         private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
778                 UniformRandomProvider rng, double mean) {
779             final double p0 = Math.exp(-mean);
780 
781             // Recursive update of Poisson probability until the value is too small
782             // p(x + 1) = p(x) * mean / (x + 1)
783             double p = p0;
784             int i = 1;
785             while (p * DOUBLE_31 >= 1) {
786                 p *= mean / i++;
787             }
788 
789             // Probabilities are 30-bit integers, assumed denominator 2^30
790             final int size = i - 1;
791             final int[] prob = new int[size];
792 
793             p = p0;
794             prob[0] = toUnsignedInt30(p);
795             // The sum must exceed 2^30. In edges cases this is false due to round-off.
796             int sum = prob[0];
797             for (i = 1; i < prob.length; i++) {
798                 p *= mean / i;
799                 prob[i] = toUnsignedInt30(p);
800                 sum += prob[i];
801             }
802 
803             // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
804             prob[(int) mean] += Math.max(0, INT_30 - sum);
805 
806             // Note: offset = 0
807             return createSampler(rng, POISSON_NAME, prob, 0);
808         }
809 
810         /**
811          * Creates the Poisson distribution by computing probabilities recursively upward and downward
812          * from {@code X=mode}, the location of the largest p-value.
813          *
814          * @param rng Generator of uniformly distributed random numbers.
815          * @param mean Mean.
816          * @return Sampler.
817          */
818         private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
819                 UniformRandomProvider rng, double mean) {
820             // If mean >= 21.4, generate from largest p-value up, then largest down.
821             // The largest p-value will be at the mode (floor(mean)).
822 
823             // Find p(x=mode)
824             final int mode = (int) mean;
825             // This transform is stable until mean >= 1941 where p will result in Infinity
826             // before the divisor i is large enough to start reducing the product (i.e. i > c).
827             final double c = mean * Math.exp(-mean / mode);
828             double p = 1.0;
829             for (int i = 1; i <= mode; i++) {
830                 p *= c / i;
831             }
832             final double pMode = p;
833 
834             // Find the upper limit using recursive computation of the p-value.
835             // Note this will exit when i overflows to negative so no check on the range
836             int i = mode + 1;
837             while (p * DOUBLE_31 >= 1) {
838                 p *= mean / i++;
839             }
840             final int last = i - 2;
841 
842             // Find the lower limit using recursive computation of the p-value.
843             p = pMode;
844             int j = -1;
845             for (i = mode - 1; i >= 0; i--) {
846                 p *= (i + 1) / mean;
847                 if (p * DOUBLE_31 < 1) {
848                     j = i;
849                     break;
850                 }
851             }
852 
853             // Probabilities are 30-bit integers, assumed denominator 2^30.
854             // This is the minimum sample value: prob[x - offset] = p(x)
855             final int offset = j + 1;
856             final int size = last - offset + 1;
857             final int[] prob = new int[size];
858 
859             p = pMode;
860             prob[mode - offset] = toUnsignedInt30(p);
861             // The sum must exceed 2^30. In edges cases this is false due to round-off.
862             int sum = prob[mode - offset];
863             // From mode to upper limit
864             for (i = mode + 1; i <= last; i++) {
865                 p *= mean / i;
866                 prob[i - offset] = toUnsignedInt30(p);
867                 sum += prob[i - offset];
868             }
869             // From mode to lower limit
870             p = pMode;
871             for (i = mode - 1; i >= offset; i--) {
872                 p *= (i + 1) / mean;
873                 prob[i - offset] = toUnsignedInt30(p);
874                 sum += prob[i - offset];
875             }
876 
877             // If the sum is < 2^30 add the remaining sum to the mode.
878             // If above 2^30 then the effect is truncation of the long tail of the distribution.
879             prob[mode - offset] += Math.max(0, INT_30 - sum);
880 
881             return createSampler(rng, POISSON_NAME, prob, offset);
882         }
883     }
884 
885     /**
886      * Create a sampler for the Binomial distribution.
887      */
888     public static final class Binomial {
889         /** The name of the Binomial distribution. */
890         private static final String BINOMIAL_NAME = "Binomial";
891 
892         /**
893          * Return a fixed result for the Binomial distribution. This is a special class to handle
894          * an edge case of probability of success equal to 0 or 1.
895          */
896         private static class MarsagliaTsangWangFixedResultBinomialSampler
897             extends AbstractMarsagliaTsangWangDiscreteSampler {
898             /** The result. */
899             private final int result;
900 
901             /**
902              * @param result Result.
903              */
904             MarsagliaTsangWangFixedResultBinomialSampler(int result) {
905                 super(null, BINOMIAL_NAME);
906                 this.result = result;
907             }
908 
909             @Override
910             public int sample() {
911                 return result;
912             }
913 
914             @Override
915             public String toString() {
916                 return BINOMIAL_NAME + " deviate";
917             }
918 
919             @Override
920             public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
921                 // No shared state
922                 return this;
923             }
924         }
925 
926         /**
927          * Return an inversion result for the Binomial distribution. This assumes the
928          * following:
929          *
930          * <pre>
931          * Binomial(n, p) = 1 - Binomial(n, 1 - p)
932          * </pre>
933          */
934         private static class MarsagliaTsangWangInversionBinomialSampler
935             extends AbstractMarsagliaTsangWangDiscreteSampler {
936             /** The number of trials. */
937             private final int trials;
938             /** The Binomial distribution sampler. */
939             private final SharedStateDiscreteSampler sampler;
940 
941             /**
942              * @param trials Number of trials.
943              * @param sampler Binomial distribution sampler.
944              */
945             MarsagliaTsangWangInversionBinomialSampler(int trials,
946                                                        SharedStateDiscreteSampler sampler) {
947                 super(null, BINOMIAL_NAME);
948                 this.trials = trials;
949                 this.sampler = sampler;
950             }
951 
952             @Override
953             public int sample() {
954                 return trials - sampler.sample();
955             }
956 
957             @Override
958             public String toString() {
959                 return sampler.toString();
960             }
961 
962             @Override
963             public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
964                 return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
965                     this.sampler.withUniformRandomProvider(rng));
966             }
967         }
968 
969         /** Class contains only static methods. */
970         private Binomial() {}
971 
972         /**
973          * Creates a sampler for the Binomial distribution.
974          *
975          * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
976          *
977          * <p>Storage requirements depend on the tabulated probability values. Example storage
978          * requirements are listed below (in kB).</p>
979          *
980          * <pre>
981          *          p
982          * trials   0.5    0.1   0.01  0.001
983          *    4    0.06   0.63   0.44   0.44
984          *   16    0.69   1.14   0.76   0.44
985          *   64    4.73   2.40   1.14   0.51
986          *  256    8.63   5.17   1.89   0.82
987          * 1024   31.12   9.45   3.34   0.89
988          * </pre>
989          *
990          * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed.
991          * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large
992          * and/or {@code p} to be small. In this case an exception is raised.</p>
993          *
994          * @param rng Generator of uniformly distributed random numbers.
995          * @param trials Number of trials.
996          * @param probabilityOfSuccess Probability of success (p).
997          * @return Sampler.
998          * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
999          * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot
1000          * be computed.
1001          */
1002         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1003                                                     int trials,
1004                                                     double probabilityOfSuccess) {
1005             validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1006 
1007             // Handle edge cases
1008             if (probabilityOfSuccess == 0) {
1009                 return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1010             }
1011             if (probabilityOfSuccess == 1) {
1012                 return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1013             }
1014 
1015             // Check the supported size.
1016             if (trials >= INT_16) {
1017                 throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1018             }
1019 
1020             return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1021         }
1022 
1023         /**
1024          * Validate the Binomial distribution parameters.
1025          *
1026          * @param trials Number of trials.
1027          * @param probabilityOfSuccess Probability of success (p).
1028          * @throws IllegalArgumentException if {@code trials < 0} or
1029          * {@code p} is not in the range {@code [0-1]}
1030          */
1031         private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1032             if (trials < 0) {
1033                 throw new IllegalArgumentException("Trials is not positive: " + trials);
1034             }
1035             if (probabilityOfSuccess < 0 || probabilityOfSuccess > 1) {
1036                 throw new IllegalArgumentException("Probability is not in range [0,1]: " + probabilityOfSuccess);
1037             }
1038         }
1039 
1040         /**
1041          * Creates the Binomial distribution sampler.
1042          *
1043          * <p>This assumes the parameters for the distribution are valid. The method
1044          * will only fail if the initial probability for {@code X=0} is zero.</p>
1045          *
1046          * @param rng Generator of uniformly distributed random numbers.
1047          * @param trials Number of trials.
1048          * @param probabilityOfSuccess Probability of success (p).
1049          * @return Sampler.
1050          * @throws IllegalArgumentException if the probability distribution cannot be
1051          * computed.
1052          */
1053         private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1054                 UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1055 
1056             // The maximum supported value for Math.exp is approximately -744.
1057             // This occurs when trials is large and p is close to 1.
1058             // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
1059             final boolean useInversion = probabilityOfSuccess > 0.5;
1060             final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1061 
1062             // Check if the distribution can be computed
1063             final double p0 = Math.exp(trials * Math.log(1 - p));
1064             if (p0 < Double.MIN_VALUE) {
1065                 throw new IllegalArgumentException("Unable to compute distribution");
1066             }
1067 
1068             // First find size of probability array
1069             double t = p0;
1070             final double h = p / (1 - p);
1071             // Find first probability above the threshold of 2^-31
1072             int begin = 0;
1073             if (t * DOUBLE_31 < 1) {
1074                 // Somewhere after p(0)
1075                 // Note:
1076                 // If this loop is entered p(0) is < 2^-31.
1077                 // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
1078                 // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
1079                 for (int i = 1; i <= trials; i++) {
1080                     t *= (trials + 1 - i) * h / i;
1081                     if (t * DOUBLE_31 >= 1) {
1082                         begin = i;
1083                         break;
1084                     }
1085                 }
1086             }
1087             // Find last probability
1088             int end = trials;
1089             for (int i = begin + 1; i <= trials; i++) {
1090                 t *= (trials + 1 - i) * h / i;
1091                 if (t * DOUBLE_31 < 1) {
1092                     end = i - 1;
1093                     break;
1094                 }
1095             }
1096 
1097             return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1098                     p0, begin, end);
1099         }
1100 
1101         /**
1102          * Creates the Binomial distribution sampler using only the probability values for {@code X}
1103          * between the begin and the end (inclusive).
1104          *
1105          * @param rng Generator of uniformly distributed random numbers.
1106          * @param trials Number of trials.
1107          * @param p Probability of success (p).
1108          * @param useInversion Set to {@code true} if the probability was inverted.
1109          * @param p0 Probability at {@code X=0}
1110          * @param begin Begin value {@code X} for the distribution.
1111          * @param end End value {@code X} for the distribution.
1112          * @return Sampler.
1113          */
1114         private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1115                 UniformRandomProvider rng, int trials, double p,
1116                 boolean useInversion, double p0, int begin, int end) {
1117 
1118             // Assign probability values as 30-bit integers
1119             final int size = end - begin + 1;
1120             final int[] prob = new int[size];
1121             double t = p0;
1122             final double h = p / (1 - p);
1123             for (int i = 1; i <= begin; i++) {
1124                 t *= (trials + 1 - i) * h / i;
1125             }
1126             int sum = toUnsignedInt30(t);
1127             prob[0] = sum;
1128             for (int i = begin + 1; i <= end; i++) {
1129                 t *= (trials + 1 - i) * h / i;
1130                 prob[i - begin] = toUnsignedInt30(t);
1131                 sum += prob[i - begin];
1132             }
1133 
1134             // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
1135             // If above 2^30 then the effect is truncation of the long tail of the distribution.
1136             final int mode = (int) ((trials + 1) * p) - begin;
1137             prob[mode] += Math.max(0, INT_30 - sum);
1138 
1139             final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1140 
1141             // Check if an inversion was made
1142             return useInversion ?
1143                    new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1144                    sampler;
1145         }
1146     }
1147 }