MarsagliaTsangWangDiscreteSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.rng.sampling.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;

  19. /**
  20.  * Sampler for a discrete distribution using an optimised look-up table.
  21.  *
  22.  * <ul>
  23.  *  <li>
  24.  *   The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
  25.  *   in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
  26.  *   Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
  27.  *  </li>
  28.  * </ul>
  29.  *
  30.  * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
  31.  *
  32.  * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
  33.  * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
  34.  * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
  35.  * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
  36.  * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
  37.  *
  38.  * <p>The sampler supports the following distributions:</p>
  39.  *
  40.  * <ul>
  41.  *  <li>Enumerated distribution (probabilities must be provided for each sample)
  42.  *  <li>Poisson distribution up to {@code mean = 1024}
  43.  *  <li>Binomial distribution up to {@code trials = 65535}
  44.  * </ul>
  45.  *
  46.  * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
  47.  * 11, Issue 3</a>
  48.  * @since 1.3
  49.  */
  50. public final class MarsagliaTsangWangDiscreteSampler {
    /**
     * The value 2<sup>8</sup> as an {@code int}.
     * Used as the exclusive limit on the maximum sample value for 8-bit table storage.
     */
    private static final int INT_8 = 1 << 8;
    /**
     * The value 2<sup>16</sup> as an {@code int}.
     * Used as the exclusive limit on the maximum sample value for 16-bit table storage.
     */
    private static final int INT_16 = 1 << 16;
    /**
     * The value 2<sup>30</sup> as an {@code int}.
     * This is the assumed denominator of the integer probabilities.
     */
    private static final int INT_30 = 1 << 30;
    /** The value 2<sup>31</sup> as a {@code double}. */
    private static final double DOUBLE_31 = 1L << 31;

  59.     // =========================================================================
  60.     // Implementation note:
  61.     //
  62.     // This sampler uses prepared look-up tables that are searched using a single
  63.     // random int variate. The look-up tables contain the sample value. The tables
  64.     // are constructed using probabilities that sum to 2^30. The original paper
  65.     // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables
  66.     // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6)
  67.     // is supported using 5 look-up tables.
  68.     //
  69.     // The implementations use 8, 16 or 32 bit storage tables to support different
  70.     // distribution sizes with optimal storage. Separate class implementations of
  71.     // the same algorithm allow array storage to be accessed directly from 1D tables.
  72.     // This provides a performance gain over using: abstracted storage accessed via
  73.     // an interface; or a single 2D table.
  74.     //
  75.     // To allow the optimal implementation to be chosen the sampler is created
  76.     // using factory methods. The sampler supports any probability distribution
  77.     // when provided via an array of probabilities and the Poisson and Binomial
  78.     // distributions for a restricted set of parameters. The restrictions are
  79.     // imposed by the requirement to compute the entire probability distribution
  80.     // from the controlling parameter(s) using a recursive method. Factory
  81.     // constructors return a SharedStateDiscreteSampler instance. Each distribution
  82.     // type is contained in an inner class.
  83.     // =========================================================================

  84.     /**
  85.      * The base class for Marsaglia-Tsang-Wang samplers.
  86.      */
  87.     private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
  88.             implements SharedStateDiscreteSampler {
  89.         /** Underlying source of randomness. */
  90.         protected final UniformRandomProvider rng;

  91.         /** The name of the distribution. */
  92.         private final String distributionName;

  93.         /**
  94.          * @param rng Generator of uniformly distributed random numbers.
  95.          * @param distributionName Distribution name.
  96.          */
  97.         AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
  98.                                                   String distributionName) {
  99.             this.rng = rng;
  100.             this.distributionName = distributionName;
  101.         }

  102.         /**
  103.          * @param rng Generator of uniformly distributed random numbers.
  104.          * @param source Source to copy.
  105.          */
  106.         AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
  107.                                                   AbstractMarsagliaTsangWangDiscreteSampler source) {
  108.             this.rng = rng;
  109.             this.distributionName = source.distributionName;
  110.         }

  111.         /** {@inheritDoc} */
  112.         @Override
  113.         public String toString() {
  114.             return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
  115.         }
  116.     }

  117.     /**
  118.      * An implementation for the sample algorithm based on the decomposition of the
  119.      * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage.
  120.      */
  121.     private static final class MarsagliaTsangWangBase64Int8DiscreteSampler
  122.         extends AbstractMarsagliaTsangWangDiscreteSampler {
  123.         /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */
  124.         private static final int MASK = 0xff;

  125.         /** Limit for look-up table 1. */
  126.         private final int t1;
  127.         /** Limit for look-up table 2. */
  128.         private final int t2;
  129.         /** Limit for look-up table 3. */
  130.         private final int t3;
  131.         /** Limit for look-up table 4. */
  132.         private final int t4;

  133.         /** Look-up table table1. */
  134.         private final byte[] table1;
  135.         /** Look-up table table2. */
  136.         private final byte[] table2;
  137.         /** Look-up table table3. */
  138.         private final byte[] table3;
  139.         /** Look-up table table4. */
  140.         private final byte[] table4;
  141.         /** Look-up table table5. */
  142.         private final byte[] table5;

  143.         /**
  144.          * @param rng Generator of uniformly distributed random numbers.
  145.          * @param distributionName Distribution name.
  146.          * @param prob The probabilities.
  147.          * @param offset The offset (must be positive).
  148.          */
  149.         MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
  150.                                                     String distributionName,
  151.                                                     int[] prob,
  152.                                                     int offset) {
  153.             super(rng, distributionName);

  154.             // Get table sizes for each base-64 digit
  155.             int n1 = 0;
  156.             int n2 = 0;
  157.             int n3 = 0;
  158.             int n4 = 0;
  159.             int n5 = 0;
  160.             for (final int m : prob) {
  161.                 n1 += getBase64Digit(m, 1);
  162.                 n2 += getBase64Digit(m, 2);
  163.                 n3 += getBase64Digit(m, 3);
  164.                 n4 += getBase64Digit(m, 4);
  165.                 n5 += getBase64Digit(m, 5);
  166.             }

  167.             table1 = new byte[n1];
  168.             table2 = new byte[n2];
  169.             table3 = new byte[n3];
  170.             table4 = new byte[n4];
  171.             table5 = new byte[n5];

  172.             // Compute offsets
  173.             t1 = n1 << 24;
  174.             t2 = t1 + (n2 << 18);
  175.             t3 = t2 + (n3 << 12);
  176.             t4 = t3 + (n4 << 6);
  177.             n1 = n2 = n3 = n4 = n5 = 0;

  178.             // Fill tables
  179.             for (int i = 0; i < prob.length; i++) {
  180.                 final int m = prob[i];
  181.                 // Primitive type conversion will extract lower 8 bits
  182.                 final byte k = (byte) (i + offset);
  183.                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
  184.                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
  185.                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
  186.                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
  187.                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
  188.             }
  189.         }

  190.         /**
  191.          * @param rng Generator of uniformly distributed random numbers.
  192.          * @param source Source to copy.
  193.          */
  194.         private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
  195.                 MarsagliaTsangWangBase64Int8DiscreteSampler source) {
  196.             super(rng, source);
  197.             t1 = source.t1;
  198.             t2 = source.t2;
  199.             t3 = source.t3;
  200.             t4 = source.t4;
  201.             table1 = source.table1;
  202.             table2 = source.table2;
  203.             table3 = source.table3;
  204.             table4 = source.table4;
  205.             table5 = source.table5;
  206.         }

  207.         /**
  208.          * Fill the table with the value.
  209.          *
  210.          * @param table Table.
  211.          * @param from Lower bound index (inclusive)
  212.          * @param to Upper bound index (exclusive)
  213.          * @param value Value.
  214.          * @return the upper bound index
  215.          */
  216.         private static int fill(byte[] table, int from, int to, byte value) {
  217.             for (int i = from; i < to; i++) {
  218.                 table[i] = value;
  219.             }
  220.             return to;
  221.         }

  222.         @Override
  223.         public int sample() {
  224.             final int j = rng.nextInt() >>> 2;
  225.             if (j < t1) {
  226.                 return table1[j >>> 24] & MASK;
  227.             }
  228.             if (j < t2) {
  229.                 return table2[(j - t1) >>> 18] & MASK;
  230.             }
  231.             if (j < t3) {
  232.                 return table3[(j - t2) >>> 12] & MASK;
  233.             }
  234.             if (j < t4) {
  235.                 return table4[(j - t3) >>> 6] & MASK;
  236.             }
  237.             // Note the tables are filled on the assumption that the sum of the probabilities.
  238.             // is >=2^30. If this is not true then the final table table5 will be smaller by the
  239.             // difference. So the tables *must* be constructed correctly.
  240.             return table5[j - t4] & MASK;
  241.         }

  242.         @Override
  243.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  244.             return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
  245.         }
  246.     }

  247.     /**
  248.      * An implementation for the sample algorithm based on the decomposition of the
  249.      * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage.
  250.      */
  251.     private static final class MarsagliaTsangWangBase64Int16DiscreteSampler
  252.         extends AbstractMarsagliaTsangWangDiscreteSampler {
  253.         /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */
  254.         private static final int MASK = 0xffff;

  255.         /** Limit for look-up table 1. */
  256.         private final int t1;
  257.         /** Limit for look-up table 2. */
  258.         private final int t2;
  259.         /** Limit for look-up table 3. */
  260.         private final int t3;
  261.         /** Limit for look-up table 4. */
  262.         private final int t4;

  263.         /** Look-up table table1. */
  264.         private final short[] table1;
  265.         /** Look-up table table2. */
  266.         private final short[] table2;
  267.         /** Look-up table table3. */
  268.         private final short[] table3;
  269.         /** Look-up table table4. */
  270.         private final short[] table4;
  271.         /** Look-up table table5. */
  272.         private final short[] table5;

  273.         /**
  274.          * @param rng Generator of uniformly distributed random numbers.
  275.          * @param distributionName Distribution name.
  276.          * @param prob The probabilities.
  277.          * @param offset The offset (must be positive).
  278.          */
  279.         MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
  280.                                                      String distributionName,
  281.                                                      int[] prob,
  282.                                                      int offset) {
  283.             super(rng, distributionName);

  284.             // Get table sizes for each base-64 digit
  285.             int n1 = 0;
  286.             int n2 = 0;
  287.             int n3 = 0;
  288.             int n4 = 0;
  289.             int n5 = 0;
  290.             for (final int m : prob) {
  291.                 n1 += getBase64Digit(m, 1);
  292.                 n2 += getBase64Digit(m, 2);
  293.                 n3 += getBase64Digit(m, 3);
  294.                 n4 += getBase64Digit(m, 4);
  295.                 n5 += getBase64Digit(m, 5);
  296.             }

  297.             table1 = new short[n1];
  298.             table2 = new short[n2];
  299.             table3 = new short[n3];
  300.             table4 = new short[n4];
  301.             table5 = new short[n5];

  302.             // Compute offsets
  303.             t1 = n1 << 24;
  304.             t2 = t1 + (n2 << 18);
  305.             t3 = t2 + (n3 << 12);
  306.             t4 = t3 + (n4 << 6);
  307.             n1 = n2 = n3 = n4 = n5 = 0;

  308.             // Fill tables
  309.             for (int i = 0; i < prob.length; i++) {
  310.                 final int m = prob[i];
  311.                 // Primitive type conversion will extract lower 16 bits
  312.                 final short k = (short) (i + offset);
  313.                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
  314.                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
  315.                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
  316.                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
  317.                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
  318.             }
  319.         }

  320.         /**
  321.          * @param rng Generator of uniformly distributed random numbers.
  322.          * @param source Source to copy.
  323.          */
  324.         private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
  325.                 MarsagliaTsangWangBase64Int16DiscreteSampler source) {
  326.             super(rng, source);
  327.             t1 = source.t1;
  328.             t2 = source.t2;
  329.             t3 = source.t3;
  330.             t4 = source.t4;
  331.             table1 = source.table1;
  332.             table2 = source.table2;
  333.             table3 = source.table3;
  334.             table4 = source.table4;
  335.             table5 = source.table5;
  336.         }

  337.         /**
  338.          * Fill the table with the value.
  339.          *
  340.          * @param table Table.
  341.          * @param from Lower bound index (inclusive)
  342.          * @param to Upper bound index (exclusive)
  343.          * @param value Value.
  344.          * @return the upper bound index
  345.          */
  346.         private static int fill(short[] table, int from, int to, short value) {
  347.             for (int i = from; i < to; i++) {
  348.                 table[i] = value;
  349.             }
  350.             return to;
  351.         }

  352.         @Override
  353.         public int sample() {
  354.             final int j = rng.nextInt() >>> 2;
  355.             if (j < t1) {
  356.                 return table1[j >>> 24] & MASK;
  357.             }
  358.             if (j < t2) {
  359.                 return table2[(j - t1) >>> 18] & MASK;
  360.             }
  361.             if (j < t3) {
  362.                 return table3[(j - t2) >>> 12] & MASK;
  363.             }
  364.             if (j < t4) {
  365.                 return table4[(j - t3) >>> 6] & MASK;
  366.             }
  367.             // Note the tables are filled on the assumption that the sum of the probabilities.
  368.             // is >=2^30. If this is not true then the final table table5 will be smaller by the
  369.             // difference. So the tables *must* be constructed correctly.
  370.             return table5[j - t4] & MASK;
  371.         }

  372.         @Override
  373.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  374.             return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
  375.         }
  376.     }

  377.     /**
  378.      * An implementation for the sample algorithm based on the decomposition of the
  379.      * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage.
  380.      */
  381.     private static final class MarsagliaTsangWangBase64Int32DiscreteSampler
  382.         extends AbstractMarsagliaTsangWangDiscreteSampler {
  383.         /** Limit for look-up table 1. */
  384.         private final int t1;
  385.         /** Limit for look-up table 2. */
  386.         private final int t2;
  387.         /** Limit for look-up table 3. */
  388.         private final int t3;
  389.         /** Limit for look-up table 4. */
  390.         private final int t4;

  391.         /** Look-up table table1. */
  392.         private final int[] table1;
  393.         /** Look-up table table2. */
  394.         private final int[] table2;
  395.         /** Look-up table table3. */
  396.         private final int[] table3;
  397.         /** Look-up table table4. */
  398.         private final int[] table4;
  399.         /** Look-up table table5. */
  400.         private final int[] table5;

  401.         /**
  402.          * @param rng Generator of uniformly distributed random numbers.
  403.          * @param distributionName Distribution name.
  404.          * @param prob The probabilities.
  405.          * @param offset The offset (must be positive).
  406.          */
  407.         MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
  408.                                                      String distributionName,
  409.                                                      int[] prob,
  410.                                                      int offset) {
  411.             super(rng, distributionName);

  412.             // Get table sizes for each base-64 digit
  413.             int n1 = 0;
  414.             int n2 = 0;
  415.             int n3 = 0;
  416.             int n4 = 0;
  417.             int n5 = 0;
  418.             for (final int m : prob) {
  419.                 n1 += getBase64Digit(m, 1);
  420.                 n2 += getBase64Digit(m, 2);
  421.                 n3 += getBase64Digit(m, 3);
  422.                 n4 += getBase64Digit(m, 4);
  423.                 n5 += getBase64Digit(m, 5);
  424.             }

  425.             table1 = new int[n1];
  426.             table2 = new int[n2];
  427.             table3 = new int[n3];
  428.             table4 = new int[n4];
  429.             table5 = new int[n5];

  430.             // Compute offsets
  431.             t1 = n1 << 24;
  432.             t2 = t1 + (n2 << 18);
  433.             t3 = t2 + (n3 << 12);
  434.             t4 = t3 + (n4 << 6);
  435.             n1 = n2 = n3 = n4 = n5 = 0;

  436.             // Fill tables
  437.             for (int i = 0; i < prob.length; i++) {
  438.                 final int m = prob[i];
  439.                 final int k = i + offset;
  440.                 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
  441.                 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
  442.                 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
  443.                 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
  444.                 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
  445.             }
  446.         }

  447.         /**
  448.          * @param rng Generator of uniformly distributed random numbers.
  449.          * @param source Source to copy.
  450.          */
  451.         private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
  452.                 MarsagliaTsangWangBase64Int32DiscreteSampler source) {
  453.             super(rng, source);
  454.             t1 = source.t1;
  455.             t2 = source.t2;
  456.             t3 = source.t3;
  457.             t4 = source.t4;
  458.             table1 = source.table1;
  459.             table2 = source.table2;
  460.             table3 = source.table3;
  461.             table4 = source.table4;
  462.             table5 = source.table5;
  463.         }

  464.         /**
  465.          * Fill the table with the value.
  466.          *
  467.          * @param table Table.
  468.          * @param from Lower bound index (inclusive)
  469.          * @param to Upper bound index (exclusive)
  470.          * @param value Value.
  471.          * @return the upper bound index
  472.          */
  473.         private static int fill(int[] table, int from, int to, int value) {
  474.             for (int i = from; i < to; i++) {
  475.                 table[i] = value;
  476.             }
  477.             return to;
  478.         }

  479.         @Override
  480.         public int sample() {
  481.             final int j = rng.nextInt() >>> 2;
  482.             if (j < t1) {
  483.                 return table1[j >>> 24];
  484.             }
  485.             if (j < t2) {
  486.                 return table2[(j - t1) >>> 18];
  487.             }
  488.             if (j < t3) {
  489.                 return table3[(j - t2) >>> 12];
  490.             }
  491.             if (j < t4) {
  492.                 return table4[(j - t3) >>> 6];
  493.             }
  494.             // Note the tables are filled on the assumption that the sum of the probabilities.
  495.             // is >=2^30. If this is not true then the final table table5 will be smaller by the
  496.             // difference. So the tables *must* be constructed correctly.
  497.             return table5[j - t4];
  498.         }

  499.         @Override
  500.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  501.             return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
  502.         }
  503.     }



    /**
     * Class contains only static methods.
     * The private constructor prevents instantiation from outside this class.
     */
    private MarsagliaTsangWangDiscreteSampler() {}

  506.     /**
  507.      * Gets the k<sup>th</sup> base 64 digit of {@code m}.
  508.      *
  509.      * @param m the value m.
  510.      * @param k the digit.
  511.      * @return the base 64 digit
  512.      */
  513.     private static int getBase64Digit(int m, int k) {
  514.         return (m >>> (30 - 6 * k)) & 63;
  515.     }

  516.     /**
  517.      * Convert the probability to an integer in the range [0,2^30]. This is the numerator of
  518.      * a fraction with assumed denominator 2<sup>30</sup>.
  519.      *
  520.      * @param p Probability.
  521.      * @return the fraction numerator
  522.      */
  523.     private static int toUnsignedInt30(double p) {
  524.         return (int) (p * INT_30 + 0.5);
  525.     }

  526.     /**
  527.      * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
  528.      * {@code i + offset}.
  529.      *
  530.      * <p>The sum of the probabilities must be {@code >=} 2<sup>30</sup>. Only the
  531.      * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
  532.      *
  533.      * @param rng Generator of uniformly distributed random numbers.
  534.      * @param distributionName Distribution name.
  535.      * @param prob The probabilities.
  536.      * @param offset The offset (must be positive).
  537.      * @return Sampler.
  538.      */
  539.     private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
  540.                                                             String distributionName,
  541.                                                             int[] prob,
  542.                                                             int offset) {
  543.         // Note: No argument checks for private method.

  544.         // Choose implementation based on the maximum index
  545.         final int maxIndex = prob.length + offset - 1;
  546.         if (maxIndex < INT_8) {
  547.             return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
  548.         }
  549.         if (maxIndex < INT_16) {
  550.             return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
  551.         }
  552.         return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
  553.     }

  554.     // =========================================================================
  555.     // The following public classes provide factory methods to construct a sampler for:
  556.     // - Enumerated probability distribution (from provided double[] probabilities)
  557.     // - Poisson distribution for mean <= 1024
  558.     // - Binomial distribution for trials <= 65535
  559.     // =========================================================================

  560.     /**
  561.      * Create a sampler for an enumerated distribution of {@code n} values each with an
  562.      * associated probability.
  563.      * The samples corresponding to each probability are assumed to be a natural sequence
  564.      * starting at zero.
  565.      */
  566.     public static final class Enumerated {
  567.         /** The name of the enumerated probability distribution. */
  568.         private static final String ENUMERATED_NAME = "Enumerated";

  569.         /** Class contains only static methods. */
  570.         private Enumerated() {}

  571.         /**
  572.          * Creates a sampler for an enumerated distribution of {@code n} values each with an
  573.          * associated probability.
  574.          *
  575.          * <p>The probabilities will be normalised using their sum. The only requirement
  576.          * is the sum is positive.</p>
  577.          *
  578.          * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that
  579.          * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during
  580.          * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup>
  581.          * will not be observed in samples.</p>
  582.          *
  583.          * @param rng Generator of uniformly distributed random numbers.
  584.          * @param probabilities The list of probabilities.
  585.          * @return Sampler.
  586.          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
  587.          * probability is negative, infinite or {@code NaN}, or the sum of all
  588.          * probabilities is not strictly positive.
  589.          */
  590.         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
  591.                                                     double[] probabilities) {
  592.             return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
  593.         }

  594.         /**
  595.          * Normalise the probabilities to integers that sum to 2<sup>30</sup>.
  596.          *
  597.          * @param probabilities The list of probabilities.
  598.          * @return the normalised probabilities.
  599.          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
  600.          * probability is negative, infinite or {@code NaN}, or the sum of all
  601.          * probabilities is not strictly positive.
  602.          */
  603.         private static int[] normaliseProbabilities(double[] probabilities) {
  604.             final double sumProb = InternalUtils.validateProbabilities(probabilities);

  605.             // Compute the normalisation: 2^30 / sum
  606.             final double normalisation = INT_30 / sumProb;
  607.             final int[] prob = new int[probabilities.length];
  608.             int sum = 0;
  609.             int max = 0;
  610.             int mode = 0;
  611.             for (int i = 0; i < prob.length; i++) {
  612.                 // Add 0.5 for rounding
  613.                 final int p = (int) (probabilities[i] * normalisation + 0.5);
  614.                 sum += p;
  615.                 // Find the mode (maximum probability)
  616.                 if (max < p) {
  617.                     max = p;
  618.                     mode = i;
  619.                 }
  620.                 prob[i] = p;
  621.             }

  622.             // The sum must be >= 2^30.
  623.             // Here just compensate the difference onto the highest probability.
  624.             prob[mode] += INT_30 - sum;

  625.             return prob;
  626.         }
  627.     }

  628.     /**
  629.      * Create a sampler for the Poisson distribution.
  630.      */
  631.     public static final class Poisson {
        /**
         * The name of the Poisson distribution. This is the name argument passed to
         * {@code createSampler} by the factory methods of this class.
         */
        private static final String POISSON_NAME = "Poisson";

        /**
         * Upper bound (inclusive) on the mean for the Poisson distribution. A mean above this
         * value is rejected with an {@link IllegalArgumentException} during validation.
         *
         * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
         * limit but the code fails at mean {@code >= 1941} as the transform to compute p(x=mode)
         * produces infinity. Use a conservative limit of 1024.</p>
         */

        private static final double MAX_MEAN = 1024;
        /**
         * The threshold for the mean of the Poisson distribution to switch the method used
         * to compute the probabilities. A mean strictly below the threshold is tabulated
         * starting from {@code X=0}; otherwise the table is built outwards from {@code X=mode}.
         * This is taken from the example software provided by Marsaglia, et al (2004).
         */
        private static final double MEAN_THRESHOLD = 21.4;

        /** Not instantiable: the class contains only static methods. */
        private Poisson() {}

  650.         /**
  651.          * Creates a sampler for the Poisson distribution.
  652.          *
  653.          * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
  654.          *
  655.          * <p>Storage requirements depend on the tabulated probability values. Example storage
  656.          * requirements are listed below.</p>
  657.          *
  658.          * <pre>
  659.          * mean      table size     kB
  660.          * 0.25      882            0.88
  661.          * 0.5       1135           1.14
  662.          * 1         1200           1.20
  663.          * 2         1451           1.45
  664.          * 4         1955           1.96
  665.          * 8         2961           2.96
  666.          * 16        4410           4.41
  667.          * 32        6115           6.11
  668.          * 64        8499           8.50
  669.          * 128       11528          11.53
  670.          * 256       15935          31.87
  671.          * 512       20912          41.82
  672.          * 1024      30614          61.23
  673.          * </pre>
  674.          *
  675.          * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p>
  676.          *
  677.          * @param rng Generator of uniformly distributed random numbers.
  678.          * @param mean Mean.
  679.          * @return Sampler.
  680.          * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
  681.          */
  682.         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
  683.                                                     double mean) {
  684.             validatePoissonDistributionParameters(mean);

  685.             // Create the distribution either from X=0 or from X=mode when the mean is high.
  686.             return mean < MEAN_THRESHOLD ?
  687.                 createPoissonDistributionFromX0(rng, mean) :
  688.                 createPoissonDistributionFromXMode(rng, mean);
  689.         }

  690.         /**
  691.          * Validate the Poisson distribution parameters.
  692.          *
  693.          * @param mean Mean.
  694.          * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}.
  695.          */
  696.         private static void validatePoissonDistributionParameters(double mean) {
  697.             InternalUtils.requireStrictlyPositive(mean, "mean");
  698.             if (mean > MAX_MEAN) {
  699.                 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
  700.             }
  701.         }

  702.         /**
  703.          * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}.
  704.          *
  705.          * @param rng Generator of uniformly distributed random numbers.
  706.          * @param mean Mean.
  707.          * @return Sampler.
  708.          */
  709.         private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
  710.                 UniformRandomProvider rng, double mean) {
  711.             final double p0 = Math.exp(-mean);

  712.             // Recursive update of Poisson probability until the value is too small
  713.             // p(x + 1) = p(x) * mean / (x + 1)
  714.             double p = p0;
  715.             int i = 1;
  716.             while (p * DOUBLE_31 >= 1) {
  717.                 p *= mean / i++;
  718.             }

  719.             // Probabilities are 30-bit integers, assumed denominator 2^30
  720.             final int size = i - 1;
  721.             final int[] prob = new int[size];

  722.             p = p0;
  723.             prob[0] = toUnsignedInt30(p);
  724.             // The sum must exceed 2^30. In edges cases this is false due to round-off.
  725.             int sum = prob[0];
  726.             for (i = 1; i < prob.length; i++) {
  727.                 p *= mean / i;
  728.                 prob[i] = toUnsignedInt30(p);
  729.                 sum += prob[i];
  730.             }

  731.             // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
  732.             prob[(int) mean] += Math.max(0, INT_30 - sum);

  733.             // Note: offset = 0
  734.             return createSampler(rng, POISSON_NAME, prob, 0);
  735.         }

  736.         /**
  737.          * Creates the Poisson distribution by computing probabilities recursively upward and downward
  738.          * from {@code X=mode}, the location of the largest p-value.
  739.          *
  740.          * @param rng Generator of uniformly distributed random numbers.
  741.          * @param mean Mean.
  742.          * @return Sampler.
  743.          */
  744.         private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
  745.                 UniformRandomProvider rng, double mean) {
  746.             // If mean >= 21.4, generate from largest p-value up, then largest down.
  747.             // The largest p-value will be at the mode (floor(mean)).

  748.             // Find p(x=mode)
  749.             final int mode = (int) mean;
  750.             // This transform is stable until mean >= 1941 where p will result in Infinity
  751.             // before the divisor i is large enough to start reducing the product (i.e. i > c).
  752.             final double c = mean * Math.exp(-mean / mode);
  753.             double p = 1.0;
  754.             for (int i = 1; i <= mode; i++) {
  755.                 p *= c / i;
  756.             }
  757.             final double pMode = p;

  758.             // Find the upper limit using recursive computation of the p-value.
  759.             // Note this will exit when i overflows to negative so no check on the range
  760.             int i = mode + 1;
  761.             while (p * DOUBLE_31 >= 1) {
  762.                 p *= mean / i++;
  763.             }
  764.             final int last = i - 2;

  765.             // Find the lower limit using recursive computation of the p-value.
  766.             p = pMode;
  767.             int j = -1;
  768.             for (i = mode - 1; i >= 0; i--) {
  769.                 p *= (i + 1) / mean;
  770.                 if (p * DOUBLE_31 < 1) {
  771.                     j = i;
  772.                     break;
  773.                 }
  774.             }

  775.             // Probabilities are 30-bit integers, assumed denominator 2^30.
  776.             // This is the minimum sample value: prob[x - offset] = p(x)
  777.             final int offset = j + 1;
  778.             final int size = last - offset + 1;
  779.             final int[] prob = new int[size];

  780.             p = pMode;
  781.             prob[mode - offset] = toUnsignedInt30(p);
  782.             // The sum must exceed 2^30. In edges cases this is false due to round-off.
  783.             int sum = prob[mode - offset];
  784.             // From mode to upper limit
  785.             for (i = mode + 1; i <= last; i++) {
  786.                 p *= mean / i;
  787.                 prob[i - offset] = toUnsignedInt30(p);
  788.                 sum += prob[i - offset];
  789.             }
  790.             // From mode to lower limit
  791.             p = pMode;
  792.             for (i = mode - 1; i >= offset; i--) {
  793.                 p *= (i + 1) / mean;
  794.                 prob[i - offset] = toUnsignedInt30(p);
  795.                 sum += prob[i - offset];
  796.             }

  797.             // If the sum is < 2^30 add the remaining sum to the mode.
  798.             // If above 2^30 then the effect is truncation of the long tail of the distribution.
  799.             prob[mode - offset] += Math.max(0, INT_30 - sum);

  800.             return createSampler(rng, POISSON_NAME, prob, offset);
  801.         }
  802.     }

  803.     /**
  804.      * Create a sampler for the Binomial distribution.
  805.      */
  806.     public static final class Binomial {
        /**
         * The name of the Binomial distribution. It is passed to {@code createSampler} and to the
         * parent constructor of the nested samplers, and prefixes the string representation of
         * the fixed-result sampler.
         */
        private static final String BINOMIAL_NAME = "Binomial";

  809.         /**
  810.          * Return a fixed result for the Binomial distribution. This is a special class to handle
  811.          * an edge case of probability of success equal to 0 or 1.
  812.          */
  813.         private static final class MarsagliaTsangWangFixedResultBinomialSampler
  814.             extends AbstractMarsagliaTsangWangDiscreteSampler {
  815.             /** The result. */
  816.             private final int result;

  817.             /**
  818.              * @param result Result.
  819.              */
  820.             MarsagliaTsangWangFixedResultBinomialSampler(int result) {
  821.                 super(null, BINOMIAL_NAME);
  822.                 this.result = result;
  823.             }

  824.             @Override
  825.             public int sample() {
  826.                 return result;
  827.             }

  828.             @Override
  829.             public String toString() {
  830.                 return BINOMIAL_NAME + " deviate";
  831.             }

  832.             @Override
  833.             public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  834.                 // No shared state
  835.                 return this;
  836.             }
  837.         }

  838.         /**
  839.          * Return an inversion result for the Binomial distribution. This assumes the
  840.          * following:
  841.          *
  842.          * <pre>
  843.          * Binomial(n, p) = 1 - Binomial(n, 1 - p)
  844.          * </pre>
  845.          */
  846.         private static final class MarsagliaTsangWangInversionBinomialSampler
  847.             extends AbstractMarsagliaTsangWangDiscreteSampler {
  848.             /** The number of trials. */
  849.             private final int trials;
  850.             /** The Binomial distribution sampler. */
  851.             private final SharedStateDiscreteSampler sampler;

  852.             /**
  853.              * @param trials Number of trials.
  854.              * @param sampler Binomial distribution sampler.
  855.              */
  856.             MarsagliaTsangWangInversionBinomialSampler(int trials,
  857.                                                        SharedStateDiscreteSampler sampler) {
  858.                 super(null, BINOMIAL_NAME);
  859.                 this.trials = trials;
  860.                 this.sampler = sampler;
  861.             }

  862.             @Override
  863.             public int sample() {
  864.                 return trials - sampler.sample();
  865.             }

  866.             @Override
  867.             public String toString() {
  868.                 return sampler.toString();
  869.             }

  870.             @Override
  871.             public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  872.                 return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
  873.                     this.sampler.withUniformRandomProvider(rng));
  874.             }
  875.         }

        /** Not instantiable: the class contains only static methods and nested sampler types. */
        private Binomial() {}

  878.         /**
  879.          * Creates a sampler for the Binomial distribution.
  880.          *
  881.          * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p>
  882.          *
  883.          * <p>Storage requirements depend on the tabulated probability values. Example storage
  884.          * requirements are listed below (in kB).</p>
  885.          *
  886.          * <pre>
  887.          *          p
  888.          * trials   0.5    0.1   0.01  0.001
  889.          *    4    0.06   0.63   0.44   0.44
  890.          *   16    0.69   1.14   0.76   0.44
  891.          *   64    4.73   2.40   1.14   0.51
  892.          *  256    8.63   5.17   1.89   0.82
  893.          * 1024   31.12   9.45   3.34   0.89
  894.          * </pre>
  895.          *
  896.          * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed.
  897.          * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large
  898.          * and/or {@code p} to be small. In this case an exception is raised.</p>
  899.          *
  900.          * @param rng Generator of uniformly distributed random numbers.
  901.          * @param trials Number of trials.
  902.          * @param probabilityOfSuccess Probability of success (p).
  903.          * @return Sampler.
  904.          * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
  905.          * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot
  906.          * be computed.
  907.          */
  908.         public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
  909.                                                     int trials,
  910.                                                     double probabilityOfSuccess) {
  911.             validateBinomialDistributionParameters(trials, probabilityOfSuccess);

  912.             // Handle edge cases
  913.             if (probabilityOfSuccess == 0) {
  914.                 return new MarsagliaTsangWangFixedResultBinomialSampler(0);
  915.             }
  916.             if (probabilityOfSuccess == 1) {
  917.                 return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
  918.             }

  919.             // Check the supported size.
  920.             if (trials >= INT_16) {
  921.                 throw new IllegalArgumentException("Unsupported number of trials: " + trials);
  922.             }

  923.             return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
  924.         }

  925.         /**
  926.          * Validate the Binomial distribution parameters.
  927.          *
  928.          * @param trials Number of trials.
  929.          * @param probabilityOfSuccess Probability of success (p).
  930.          * @throws IllegalArgumentException if {@code trials < 0} or
  931.          * {@code p} is not in the range {@code [0-1]}
  932.          */
  933.         private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
  934.             if (trials < 0) {
  935.                 throw new IllegalArgumentException("Trials is not positive: " + trials);
  936.             }
  937.             InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success");
  938.         }

  939.         /**
  940.          * Creates the Binomial distribution sampler.
  941.          *
  942.          * <p>This assumes the parameters for the distribution are valid. The method
  943.          * will only fail if the initial probability for {@code X=0} is zero.</p>
  944.          *
  945.          * @param rng Generator of uniformly distributed random numbers.
  946.          * @param trials Number of trials.
  947.          * @param probabilityOfSuccess Probability of success (p).
  948.          * @return Sampler.
  949.          * @throws IllegalArgumentException if the probability distribution cannot be
  950.          * computed.
  951.          */
  952.         private static SharedStateDiscreteSampler createBinomialDistributionSampler(
  953.                 UniformRandomProvider rng, int trials, double probabilityOfSuccess) {

  954.             // The maximum supported value for Math.exp is approximately -744.
  955.             // This occurs when trials is large and p is close to 1.
  956.             // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
  957.             final boolean useInversion = probabilityOfSuccess > 0.5;
  958.             final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;

  959.             // Check if the distribution can be computed
  960.             final double p0 = Math.exp(trials * Math.log(1 - p));
  961.             if (p0 < Double.MIN_VALUE) {
  962.                 throw new IllegalArgumentException("Unable to compute distribution");
  963.             }

  964.             // First find size of probability array
  965.             double t = p0;
  966.             final double h = p / (1 - p);
  967.             // Find first probability above the threshold of 2^-31
  968.             int begin = 0;
  969.             if (t * DOUBLE_31 < 1) {
  970.                 // Somewhere after p(0)
  971.                 // Note:
  972.                 // If this loop is entered p(0) is < 2^-31.
  973.                 // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
  974.                 // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
  975.                 for (int i = 1; i <= trials; i++) {
  976.                     t *= (trials + 1 - i) * h / i;
  977.                     if (t * DOUBLE_31 >= 1) {
  978.                         begin = i;
  979.                         break;
  980.                     }
  981.                 }
  982.             }
  983.             // Find last probability
  984.             int end = trials;
  985.             for (int i = begin + 1; i <= trials; i++) {
  986.                 t *= (trials + 1 - i) * h / i;
  987.                 if (t * DOUBLE_31 < 1) {
  988.                     end = i - 1;
  989.                     break;
  990.                 }
  991.             }

  992.             return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
  993.                     p0, begin, end);
  994.         }

  995.         /**
  996.          * Creates the Binomial distribution sampler using only the probability values for {@code X}
  997.          * between the begin and the end (inclusive).
  998.          *
  999.          * @param rng Generator of uniformly distributed random numbers.
  1000.          * @param trials Number of trials.
  1001.          * @param p Probability of success (p).
  1002.          * @param useInversion Set to {@code true} if the probability was inverted.
  1003.          * @param p0 Probability at {@code X=0}
  1004.          * @param begin Begin value {@code X} for the distribution.
  1005.          * @param end End value {@code X} for the distribution.
  1006.          * @return Sampler.
  1007.          */
  1008.         private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
  1009.                 UniformRandomProvider rng, int trials, double p,
  1010.                 boolean useInversion, double p0, int begin, int end) {

  1011.             // Assign probability values as 30-bit integers
  1012.             final int size = end - begin + 1;
  1013.             final int[] prob = new int[size];
  1014.             double t = p0;
  1015.             final double h = p / (1 - p);
  1016.             for (int i = 1; i <= begin; i++) {
  1017.                 t *= (trials + 1 - i) * h / i;
  1018.             }
  1019.             int sum = toUnsignedInt30(t);
  1020.             prob[0] = sum;
  1021.             for (int i = begin + 1; i <= end; i++) {
  1022.                 t *= (trials + 1 - i) * h / i;
  1023.                 prob[i - begin] = toUnsignedInt30(t);
  1024.                 sum += prob[i - begin];
  1025.             }

  1026.             // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
  1027.             // If above 2^30 then the effect is truncation of the long tail of the distribution.
  1028.             final int mode = (int) ((trials + 1) * p) - begin;
  1029.             prob[mode] += Math.max(0, INT_30 - sum);

  1030.             final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);

  1031.             // Check if an inversion was made
  1032.             return useInversion ?
  1033.                    new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
  1034.                    sampler;
  1035.         }
  1036.     }
  1037. }