FastLoadedDiceRollerDiscreteSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.rng.sampling.distribution;

  18. import java.math.BigInteger;
  19. import java.util.Arrays;
  20. import org.apache.commons.rng.UniformRandomProvider;

  21. /**
  22.  * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
  23.  * sample from {@code n} values each with an associated relative weight. If all unique items
  24.  * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
  25.  *
  26.  * <p>Given a list {@code L} of {@code n} positive numbers,
  27.  * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
  28.  * integer {@code i} with relative probability {@code L[i]}.
  29.  *
  30.  * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
  31.  * <ul>
  32.  *   <li>For integer weights, the probability of returning {@code i} is precisely equal to the
  33.  *   rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
  34.  *   <li>For floating-point weights, each weight {@code L[i]} is converted to the
  35.  *   corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
  36.  *   {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
  37.  * </ul>
  38.  *
  39.  * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
  40.  * ignores very small relative weights may have improved sampling performance.
  41.  *
  42.  * <p>This implementation is based on the algorithm in:
  43.  *
  44.  * <blockquote>
  45.  *  Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
  46.  *  The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
  47.  *  Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
  48.  *  Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
  49.  *  Palermo, Sicily, Italy, 2020.
  50.  * </blockquote>
  51.  *
  52.  * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
  53.  *
  54.  * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
  55.  * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
  56.  * PMLR 108:1036-1046.</a>
  57.  * @since 1.5
  58.  */
  59. public abstract class FastLoadedDiceRollerDiscreteSampler
  60.     implements SharedStateDiscreteSampler {
  61.     /**
  62.      * The maximum size of an array.
  63.      *
  64.      * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
  65.      * It allows VMs to reserve some header words in an array.
  66.      */
  67.     private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
  68.     /** The maximum biased exponent for a finite double.
  69.      * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */
  70.     private static final int MAX_BIASED_EXPONENT = 2046;
  71.     /** Size of the mantissa of a double. Equal to 52 bits. */
  72.     private static final int MANTISSA_SIZE = 52;
  73.     /** Mask to extract the 52-bit mantissa from a long representation of a double. */
  74.     private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
  75.     /** BigInteger representation of {@link Long#MAX_VALUE}. */
  76.     private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
  77.     /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
  78.      * A 53-bit value shifted left by 10 uses at most 63 bits so it remains a positive long. */
  79.     private static final int MAX_OFFSET = 10;
  80.     /** Marker used to initialise the leaf node label table: the entry has no label. */
  81.     private static final int NO_LABEL = Integer.MAX_VALUE;
  82.     /** Name of the sampler. */
  83.     private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";

  84.     /**
  85.      * Class to handle the edge case of observations in only one category.
  86.      */
  87.     private static final class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
  88.         /** The sample value. */
  89.         private final int sampleValue;

  90.         /**
  91.          * @param sampleValue Sample value.
  92.          */
  93.         FixedValueDiscreteSampler(int sampleValue) {
  94.             this.sampleValue = sampleValue;
  95.         }

  96.         @Override
  97.         public int sample() {
  98.             return sampleValue;
  99.         }

  100.         @Override
  101.         public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  102.             return this;
  103.         }

  104.         @Override
  105.         public String toString() {
  106.             return SAMPLER_NAME;
  107.         }
  108.     }

  109.     /**
  110.      * Class to implement the FLDR sample algorithm.
  111.      */
  112.     private static final class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
  113.         /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
  114.          * the boolean source. */
  115.         private static final int EMPTY_BOOL_SOURCE = 1;

  116.         /** Underlying source of randomness. */
  117.         private final UniformRandomProvider rng;
  118.         /** Number of categories. */
  119.         private final int n;
  120.         /** Number of levels in the discrete distribution generating (DDG) tree.
  121.          * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
  122.         private final int k;
  123.         /** Number of leaf nodes at each level. */
  124.         private final int[] h;
  125.         /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
  126.         private final int[] lH;

  127.         /**
  128.          * Provides a bit source for booleans.
  129.          *
  130.          * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
  131.          *
  132.          * <p>Only stores 31-bits when full as 1 bit has already been consumed.
  133.          * The sign bit is a flag that shifts down so the source eventually equals 1
  134.          * when all bits are consumed and will trigger a refill.
  135.          */
  136.         private int booleanSource = EMPTY_BOOL_SOURCE;

  137.         /**
  138.          * Creates a sampler.
  139.          *
  140.          * <p>The input parameters are not validated and must be correctly computed tables.
  141.          *
  142.          * @param rng Generator of uniformly distributed random numbers.
  143.          * @param n Number of categories
  144.          * @param k Number of levels in the discrete distribution generating (DDG) tree.
  145.          * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
  146.          * @param h Number of leaf nodes at each level.
  147.          * @param lH Stores the leaf node labels in increasing order.
  148.          */
  149.         FLDRSampler(UniformRandomProvider rng,
  150.                     int n,
  151.                     int k,
  152.                     int[] h,
  153.                     int[] lH) {
  154.             this.rng = rng;
  155.             this.n = n;
  156.             this.k = k;
  157.             // Deliberate direct storage of input arrays
  158.             this.h = h;
  159.             this.lH = lH;
  160.         }

  161.         /**
  162.          * Creates a copy with a new source of randomness.
  163.          *
  164.          * @param rng Generator of uniformly distributed random numbers.
  165.          * @param source Source to copy.
  166.          */
  167.         private FLDRSampler(UniformRandomProvider rng,
  168.                             FLDRSampler source) {
  169.             this.rng = rng;
  170.             this.n = source.n;
  171.             this.k = source.k;
  172.             this.h = source.h;
  173.             this.lH = source.lH;
  174.         }

  175.         /** {@inheritDoc} */
  176.         @Override
  177.         public int sample() {
  178.             // ALGORITHM 5: SAMPLE
  179.             int c = 0;
  180.             int d = 0;
  181.             for (;;) {
  182.                 // b = flip()
  183.                 // d = 2 * d + (1 - b)
  184.                 d = (d << 1) + flip();
  185.                 if (d < h[c]) {
  186.                     // z = H[d][c]
  187.                     final int z = lH[d * k + c];
  188.                     // assert z != NO_LABEL
  189.                     if (z < n) {
  190.                         return z;
  191.                     }
  192.                     d = 0;
  193.                     c = 0;
  194.                 } else {
  195.                     d = d - h[c];
  196.                     c++;
  197.                 }
  198.             }
  199.         }

  200.         /**
  201.          * Provides a source of boolean bits.
  202.          *
  203.          * <p>Note: This replicates the boolean cache functionality of
  204.          * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
  205.          * an {@code int} value rather than a {@code boolean}.
  206.          *
  207.          * @return the bit (0 or 1)
  208.          */
  209.         private int flip() {
  210.             int bits = booleanSource;
  211.             if (bits == 1) {
  212.                 // Refill
  213.                 bits = rng.nextInt();
  214.                 // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
  215.                 booleanSource = Integer.MIN_VALUE | (bits >>> 1);
  216.                 return bits & 0x1;
  217.             }
  218.             // Shift down eventually triggering refill, return current lowest bit
  219.             booleanSource = bits >>> 1;
  220.             return bits & 0x1;
  221.         }

  222.         /** {@inheritDoc} */
  223.         @Override
  224.         public String toString() {
  225.             return SAMPLER_NAME + " [" + rng.toString() + "]";
  226.         }

  227.         /** {@inheritDoc} */
  228.         @Override
  229.         public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  230.             return new FLDRSampler(rng, this);
  231.         }
  232.     }

  233.     /** Package-private constructor. Instances are obtained from the static {@code of} methods. */
  234.     FastLoadedDiceRollerDiscreteSampler() {
  235.         // Intentionally empty
  236.     }

  237.     /** {@inheritDoc} */
  238.     // Redeclare the signature to return a FastLoadedDiceRollerDiscreteSampler not a SharedStateDiscreteSampler
  239.     @Override
  240.     public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);

  241.     /**
  242.      * Creates a sampler.
  243.      *
  244.      * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
  245.      * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
  246.      * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
  247.      * as a single array.
  248.      *
  249.      * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
  250.      * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
  251.      * when the sum of frequencies is large enough to create k=63.
  252.      *
  253.      * @param rng Generator of uniformly distributed random numbers.
  254.      * @param frequencies Observed frequencies of the discrete distribution.
  255.      * @return the sampler
  256.      * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
  257.      * frequency is negative, the sum of all frequencies is either zero or
  258.      * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
  259.      * is too large.
  260.      */
  261.     public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
  262.                                                          long[] frequencies) {
  263.         final long m = sum(frequencies);

  264.         // Obtain indices of non-zero frequencies
  265.         final int[] indices = indicesOfNonZero(frequencies);

  266.         // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
  267.         // (as log2(m) == 0 will break the computation of the DDG tree).
  268.         if (indices.length == 1) {
  269.             return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
  270.         }

  271.         return createSampler(rng, frequencies, indices, m);
  272.     }

  273.     /**
  274.      * Creates a sampler.
  275.      *
  276.      * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
  277.      * The numerators {@code p} are scaled to use a common denominator before summing.
  278.      *
  279.      * <p>All weights are used to create the sampler. Weights with a small magnitude relative
  280.      * to the largest weight can be excluded using the constructor method with the
  281.      * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
  282.      *
  283.      * @param rng Generator of uniformly distributed random numbers.
  284.      * @param weights Weights of the discrete distribution.
  285.      * @return the sampler
  286.      * @throws IllegalArgumentException if {@code weights} is null or empty, a
  287.      * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
  288.      * of the discrete distribution generating tree is too large.
  289.      * @see #of(UniformRandomProvider, double[], int)
  290.      */
  291.     public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
  292.                                                          double[] weights) {
  293.         return of(rng, weights, 0);
  294.     }

  295.     /**
  296.      * Creates a sampler.
  297.      *
  298.      * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
  299.      * a power of 2. The numerators {@code p} are scaled to use a common
  300.      * denominator before summing.
  301.      *
  302.      * <p>Note: The discrete distribution generating (DDG) tree requires
  303.      * {@code (n + 1) * k} entries where {@code n} is the number of categories,
  304.      * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators
  305.      * {@code q}. An exception is raised if this cannot be allocated as a single
  306.      * array.
  307.      *
  308.      * <p>For reference the value {@code k} is equal to or greater than the ratio of
  309.      * the largest to the smallest weight expressed as a power of 2. For
  310.      * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
  311.      * {@code k} increases with the sum of the weight numerators. A number of
  312.      * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
  313.      * would be required to raise an exception when the minimum weight is
  314.      * {@link Double#MIN_VALUE}.
  315.      *
  316.      * <p>Weights with a small magnitude relative to the largest weight can be
  317.      * excluded using the relative magnitude parameter {@code alpha}. This will set
  318.      * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
  319.      * <em>smaller</em> than the largest weight. This comparison is made using only
  320.      * the exponent of the input weights. The {@code alpha} parameter is ignored if
  321.      * not above zero. Note that a small {@code alpha} parameter will exclude more
  322.      * weights than a large {@code alpha} parameter.
  323.      *
  324.      * <p>The alpha parameter can be used to exclude categories that
  325.      * have a very low probability of occurrence and will improve the construction
  326.      * performance of the sampler. The effect on sampling performance depends on
  327.      * the relative weights of the excluded categories; typically a high {@code alpha}
  328.      * is used to exclude categories that would be visited with a very low probability
  329.      * and the sampling performance is unchanged.
  330.      *
  331.      * <p><b>Implementation Note</b>
  332.      *
  333.      * <p>This method creates a sampler with <em>exact</em> samples from the
  334.      * specified probability distribution. It is recommended to use this method:
  335.      * <ul>
  336.      *  <li>if the weights are computed, for example from a probability mass function; or
  337.      *  <li>if the weights sum to an infinite value.
  338.      * </ul>
  339.      *
  340.      * <p>If the weights are computed from empirical observations then it is
  341.      * recommended to use the factory method
  342.      * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
  343.      * requires the total number of observations to be representable as a long
  344.      * integer.
  345.      *
  346.      * <p>Note that if all weights are scaled by a power of 2 to be integers, and
  347.      * each integer can be represented as a positive 64-bit long value, then the
  348.      * sampler created using this method will match the output from a sampler
  349.      * created with the scaled weights converted to long values for the factory
  350.      * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
  351.      * assumes the sum of the integer values does not overflow.
  352.      *
  353.      * <p>It should be noted that the conversion of weights to rational numbers has
  354.      * a performance overhead during construction (sampling performance is not
  355.      * affected). This may be avoided by first converting them to integer values
  356.      * that can be summed without overflow. For example by scaling values by
  357.      * {@code 2^62 / sum} and converting to long by casting or rounding.
  358.      *
  359.      * <p>This approach may increase the efficiency of construction. The resulting
  360.      * sampler may no longer produce <em>exact</em> samples from the distribution.
  361.      * In particular any weights with a converted frequency of zero cannot be
  362.      * sampled.
  363.      *
  364.      * @param rng Generator of uniformly distributed random numbers.
  365.      * @param weights Weights of the discrete distribution.
  366.      * @param alpha Alpha parameter.
  367.      * @return the sampler
  368.      * @throws IllegalArgumentException if {@code weights} is null or empty, a
  369.      * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
  370.      * or the size of the discrete distribution generating tree is too large.
  371.      * @see #of(UniformRandomProvider, long[])
  372.      */
  373.     public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
  374.                                                          double[] weights,
  375.                                                          int alpha) {
  376.         final int n = checkWeightsNonZeroLength(weights);

  377.         // Convert floating-point double to a relative weight
  378.         // using a shifted integer representation
  379.         final long[] frequencies = new long[n];
  380.         final int[] offsets = new int[n];
  381.         convertToIntegers(weights, frequencies, offsets, alpha);

  382.         // Obtain indices of non-zero weights
  383.         final int[] indices = indicesOfNonZero(frequencies);

  384.         // Edge case for 1 non-zero weight.
  385.         if (indices.length == 1) {
  386.             return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
  387.         }

  388.         final BigInteger m = sum(frequencies, offsets, indices);

  389.         // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
  390.         if (m.compareTo(MAX_LONG) <= 0) {
  391.             // Apply the offset
  392.             for (int i = 0; i < n; i++) {
  393.                 frequencies[i] <<= offsets[i];
  394.             }
  395.             return createSampler(rng, frequencies, indices, m.longValue());
  396.         }

  397.         return createSampler(rng, frequencies, offsets, indices, m);
  398.     }

  399.     /**
  400.      * Sum the frequencies.
  401.      *
  402.      * @param frequencies Frequencies.
  403.      * @return the sum
  404.      * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
  405.      * frequency is negative, or the sum of all frequencies is either zero or above
  406.      * {@link Long#MAX_VALUE}
  407.      */
  408.     private static long sum(long[] frequencies) {
  409.         // Validate
  410.         if (frequencies == null || frequencies.length == 0) {
  411.             throw new IllegalArgumentException("frequencies must contain at least 1 value");
  412.         }

  413.         // Sum the values.
  414.         // Combine all the sign bits in the observations and the intermediate sum in a flag.
  415.         long m = 0;
  416.         long signFlag = 0;
  417.         for (final long o : frequencies) {
  418.             m += o;
  419.             signFlag |= o | m;
  420.         }

  421.         // Check for a sign-bit.
  422.         if (signFlag < 0) {
  423.             // One or more observations were negative, or the sum overflowed.
  424.             for (final long o : frequencies) {
  425.                 if (o < 0) {
  426.                     throw new IllegalArgumentException("frequencies must contain positive values: " + o);
  427.                 }
  428.             }
  429.             throw new IllegalArgumentException("Overflow when summing frequencies");
  430.         }
  431.         if (m == 0) {
  432.             throw new IllegalArgumentException("Sum of frequencies is zero");
  433.         }
  434.         return m;
  435.     }

  436.     /**
  437.      * Convert the floating-point weights to relative weights represented as
  438.      * integers {@code value * 2^exponent}. The relative weight as an integer is:
  439.      *
  440.      * <pre>
  441.      * BigInteger.valueOf(value).shiftLeft(exponent)
  442.      * </pre>
  443.      *
  444.      * <p>Note that the weights are created using a common power-of-2 scaling
  445.      * operation so the minimum exponent is zero.
  446.      *
  447.      * <p>A positive {@code alpha} parameter is used to set any weight to zero if
  448.      * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
  449.      * largest weight. This comparison is made using only the exponent of the input
  450.      * weights.
  451.      *
  452.      * @param weights Weights of the discrete distribution.
  453.      * @param values Output floating-point mantissas converted to 53-bit integers.
  454.      * @param exponents Output power of 2 exponent.
  455.      * @param alpha Alpha parameter.
  456.      * @throws IllegalArgumentException if a weight is negative, infinite or
  457.      * {@code NaN}, or the sum of all weights is zero.
  458.      */
  459.     private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
  460.         int maxExponent = Integer.MIN_VALUE;
  461.         for (int i = 0; i < weights.length; i++) {
  462.             final double weight = weights[i];
  463.             // Ignore zero.
  464.             // When creating the integer value later using bit shifts the result will remain zero.
  465.             if (weight == 0) {
  466.                 continue;
  467.             }
  468.             final long bits = Double.doubleToRawLongBits(weight);

  469.             // For the IEEE 754 format see Double.longBitsToDouble(long).

  470.             // Extract the exponent (with the sign bit)
  471.             int exp = (int) (bits >>> MANTISSA_SIZE);
  472.             // Detect negative, infinite or NaN.
  473.             // Note: Negative values sign bit will cause the exponent to be too high.
  474.             if (exp > MAX_BIASED_EXPONENT) {
  475.                 throw new IllegalArgumentException("Invalid weight: " + weight);
  476.             }
  477.             long mantissa;
  478.             if (exp == 0) {
  479.                 // Sub-normal number:
  480.                 mantissa = (bits & MANTISSA_MASK) << 1;
  481.                 // Here we convert to a normalised number by counting the leading zeros
  482.                 // to obtain the number of shifts of the most significant bit in
  483.                 // the mantissa that is required to get a 1 at position 53 (i.e. as
  484.                 // if it were a normal number with assumed leading bit).
  485.                 final int shift = Long.numberOfLeadingZeros(mantissa << 11);
  486.                 mantissa <<= shift;
  487.                 exp -= shift;
  488.             } else {
  489.                 // Normal number. Add the implicit leading 1-bit.
  490.                 mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
  491.             }

  492.             // Here the floating-point number is equal to:
  493.             // mantissa * 2^(exp-1075)

  494.             values[i] = mantissa;
  495.             exponents[i] = exp;
  496.             maxExponent = Math.max(maxExponent, exp);
  497.         }

  498.         // No exponent indicates that all weights are zero
  499.         if (maxExponent == Integer.MIN_VALUE) {
  500.             throw new IllegalArgumentException("Sum of weights is zero");
  501.         }

  502.         filterWeights(values, exponents, alpha, maxExponent);
  503.         scaleWeights(values, exponents);
  504.     }

  505.     /**
  506.      * Filters small weights using the {@code alpha} parameter.
  507.      * A positive {@code alpha} parameter is used to set any weight to zero if
  508.      * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
  509.      * largest weight. This comparison is made using only the exponent of the input
  510.      * weights.
  511.      *
  512.      * @param values 53-bit values.
  513.      * @param exponents Power of 2 exponent.
  514.      * @param alpha Alpha parameter.
  515.      * @param maxExponent Maximum exponent.
  516.      */
  517.     private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
  518.         if (alpha > 0) {
  519.             // Filter weights. This must be done before the values are shifted so
  520.             // the exponent represents the approximate magnitude of the value.
  521.             for (int i = 0; i < exponents.length; i++) {
  522.                 if (maxExponent - exponents[i] > alpha) {
  523.                     values[i] = 0;
  524.                 }
  525.             }
  526.         }
  527.     }

  528.     /**
  529.      * Scale the weights represented as integers {@code value * 2^exponent} to use a
  530.      * minimum exponent of zero. The values are scaled to remove any common trailing zeros
  531.      * in their representation. This ultimately reduces the size of the discrete distribution
  532.      * generating (DGG) tree.
  533.      *
  534.      * @param values 53-bit values.
  535.      * @param exponents Power of 2 exponent.
  536.      */
  537.     private static void scaleWeights(long[] values, int[] exponents) {
  538.         // Find the minimum exponent and common trailing zeros.
  539.         int minExponent = Integer.MAX_VALUE;
  540.         for (int i = 0; i < exponents.length; i++) {
  541.             if (values[i] != 0) {
  542.                 minExponent = Math.min(minExponent, exponents[i]);
  543.             }
  544.         }
  545.         // Trailing zeros occur when the original weights have a representation with
  546.         // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
  547.         int trailingZeros = Long.SIZE;
  548.         for (int i = 0; i < values.length && trailingZeros != 0; i++) {
  549.             trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
  550.         }
  551.         // Scale by a power of 2 so the minimum exponent is zero.
  552.         for (int i = 0; i < exponents.length; i++) {
  553.             exponents[i] -= minExponent;
  554.         }
  555.         // Remove common trailing zeros.
  556.         if (trailingZeros != 0) {
  557.             for (int i = 0; i < values.length; i++) {
  558.                 values[i] >>>= trailingZeros;
  559.             }
  560.         }
  561.     }

  562.     /**
  563.      * Sum the integers at the specified indices.
  564.      * Integers are represented as {@code value * 2^exponent}.
  565.      *
  566.      * @param values 53-bit values.
  567.      * @param exponents Power of 2 exponent.
  568.      * @param indices Indices to sum.
  569.      * @return the sum
  570.      */
  571.     private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
  572.         BigInteger m = BigInteger.ZERO;
  573.         for (final int i : indices) {
  574.             m = m.add(toBigInteger(values[i], exponents[i]));
  575.         }
  576.         return m;
  577.     }

  578.     /**
  579.      * Convert the value and left shift offset to a BigInteger.
  580.      * It is assumed the value is at most 53-bits. This allows optimising the left
  581.      * shift if it is below 11 bits.
  582.      *
  583.      * @param value 53-bit value.
  584.      * @param offset Left shift offset (must be positive).
  585.      * @return the BigInteger
  586.      */
  587.     private static BigInteger toBigInteger(long value, int offset) {
  588.         // Ignore zeros. The sum method uses indices of non-zero values.
  589.         if (offset <= MAX_OFFSET) {
  590.             // Assume (value << offset) <= Long.MAX_VALUE
  591.             return BigInteger.valueOf(value << offset);
  592.         }
  593.         return BigInteger.valueOf(value).shiftLeft(offset);
  594.     }

  595.     /**
  596.      * Creates the sampler.
  597.      *
  598.      * <p>It is assumed the frequencies are all positive and the sum does not
  599.      * overflow.
  600.      *
  601.      * @param rng Generator of uniformly distributed random numbers.
  602.      * @param frequencies Observed frequencies of the discrete distribution.
  603.      * @param indices Indices of non-zero frequencies.
  604.      * @param m Sum of the frequencies.
  605.      * @return the sampler
  606.      */
  607.     private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
  608.                                                                      long[] frequencies,
  609.                                                                      int[] indices,
  610.                                                                      long m) {
  611.         // ALGORITHM 5: PREPROCESS
  612.         // a == frequencies
  613.         // m = sum(a)
  614.         // h = leaf node count
  615.         // H = leaf node label (lH)

  616.         final int n = frequencies.length;

  617.         // k = ceil(log2(m))
  618.         final int k = 64 - Long.numberOfLeadingZeros(m - 1);
  619.         // r = a(n+1) = 2^k - m
  620.         final long r = (1L << k) - m;

  621.         // Note:
  622.         // A sparse matrix can often be used for H, as most of its entries are empty.
  623.         // This implementation uses a 1D array for efficiency at the cost of memory.
  624.         // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
  625.         // observations is large enough to create k=63.
  626.         // This could be handled using a 2D array. In practice a number of categories this
  627.         // large is not expected and is currently not supported.
  628.         final int[] h = new int[k];
  629.         final int[] lH = new int[checkArraySize((n + 1L) * k)];
  630.         Arrays.fill(lH, NO_LABEL);

  631.         int d;
  632.         for (int j = 0; j < k; j++) {
  633.             final int shift = (k - 1) - j;
  634.             final long bitMask = 1L << shift;

  635.             d = 0;
  636.             for (final int i : indices) {
  637.                 // bool w ← (a[i] >> (k − 1) − j)) & 1
  638.                 // h[j] = h[j] + w
  639.                 // if w then:
  640.                 if ((frequencies[i] & bitMask) != 0) {
  641.                     h[j]++;
  642.                     // H[d][j] = i
  643.                     lH[d * k + j] = i;
  644.                     d++;
  645.                 }
  646.             }
  647.             // process a(n+1) without extending the input frequencies array by 1
  648.             if ((r & bitMask) != 0) {
  649.                 h[j]++;
  650.                 lH[d * k + j] = n;
  651.             }
  652.         }

  653.         return new FLDRSampler(rng, n, k, h, lH);
  654.     }

  655.     /**
  656.      * Creates the sampler. Frequencies are represented as a 53-bit value with a
  657.      * left-shift offset.
  658.      * <pre>
  659.      * BigInteger.valueOf(value).shiftLeft(offset)
  660.      * </pre>
  661.      *
  662.      * <p>It is assumed the frequencies are all positive.
  663.      *
  664.      * @param rng Generator of uniformly distributed random numbers.
  665.      * @param frequencies Observed frequencies of the discrete distribution.
  666.      * @param offsets Left shift offsets (must be positive).
  667.      * @param indices Indices of non-zero frequencies.
  668.      * @param m Sum of the frequencies.
  669.      * @return the sampler
  670.      */
  671.     private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
  672.                                                                      long[] frequencies,
  673.                                                                      int[] offsets,
  674.                                                                      int[] indices,
  675.                                                                      BigInteger m) {
  676.         // Repeat the logic from createSampler(...) using extended arithmetic to test the bits

  677.         // ALGORITHM 5: PREPROCESS
  678.         // a == frequencies
  679.         // m = sum(a)
  680.         // h = leaf node count
  681.         // H = leaf node label (lH)

  682.         final int n = frequencies.length;

  683.         // k = ceil(log2(m))
  684.         final int k = m.subtract(BigInteger.ONE).bitLength();
  685.         // r = a(n+1) = 2^k - m
  686.         final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);

  687.         final int[] h = new int[k];
  688.         final int[] lH = new int[checkArraySize((n + 1L) * k)];
  689.         Arrays.fill(lH, NO_LABEL);

  690.         int d;
  691.         for (int j = 0; j < k; j++) {
  692.             final int shift = (k - 1) - j;

  693.             d = 0;
  694.             for (final int i : indices) {
  695.                 // bool w ← (a[i] >> (k − 1) − j)) & 1
  696.                 // h[j] = h[j] + w
  697.                 // if w then:
  698.                 if (testBit(frequencies[i], offsets[i], shift)) {
  699.                     h[j]++;
  700.                     // H[d][j] = i
  701.                     lH[d * k + j] = i;
  702.                     d++;
  703.                 }
  704.             }
  705.             // process a(n+1) without extending the input frequencies array by 1
  706.             if (r.testBit(shift)) {
  707.                 h[j]++;
  708.                 lH[d * k + j] = n;
  709.             }
  710.         }

  711.         return new FLDRSampler(rng, n, k, h, lH);
  712.     }

  713.     /**
  714.      * Test the logical bit of the shifted integer representation.
  715.      * The value is assumed to have at most 53-bits of information. The offset
  716.      * is assumed to be positive. This is functionally equivalent to:
  717.      * <pre>
  718.      * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
  719.      * </pre>
  720.      *
  721.      * @param value 53-bit value.
  722.      * @param offset Left shift offset.
  723.      * @param n Index of bit to test.
  724.      * @return true if the bit is 1
  725.      */
  726.     private static boolean testBit(long value, int offset, int n) {
  727.         if (n < offset) {
  728.             // All logical trailing bits are zero
  729.             return false;
  730.         }
  731.         // Test if outside the 53-bit value (note that the implicit 1 bit
  732.         // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
  733.         final int bit = n - offset;
  734.         return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
  735.     }

  736.     /**
  737.      * Check the weights have a non-zero length.
  738.      *
  739.      * @param weights Weights.
  740.      * @return the length
  741.      */
  742.     private static int checkWeightsNonZeroLength(double[] weights) {
  743.         if (weights == null || weights.length == 0) {
  744.             throw new IllegalArgumentException("weights must contain at least 1 value");
  745.         }
  746.         return weights.length;
  747.     }

  748.     /**
  749.      * Create the indices of non-zero values.
  750.      *
  751.      * @param values Values.
  752.      * @return the indices
  753.      */
  754.     private static int[] indicesOfNonZero(long[] values) {
  755.         int n = 0;
  756.         final int[] indices = new int[values.length];
  757.         for (int i = 0; i < values.length; i++) {
  758.             if (values[i] != 0) {
  759.                 indices[n++] = i;
  760.             }
  761.         }
  762.         return Arrays.copyOf(indices, n);
  763.     }

  764.     /**
  765.      * Find the index of the first non-zero frequency.
  766.      *
  767.      * @param frequencies Frequencies.
  768.      * @return the index
  769.      * @throws IllegalStateException if all frequencies are zero.
  770.      */
  771.     static int indexOfNonZero(long[] frequencies) {
  772.         for (int i = 0; i < frequencies.length; i++) {
  773.             if (frequencies[i] != 0) {
  774.                 return i;
  775.             }
  776.         }
  777.         throw new IllegalStateException("All frequencies are zero");
  778.     }

  779.     /**
  780.      * Check the size is valid for a 1D array.
  781.      *
  782.      * @param size Size
  783.      * @return the size as an {@code int}
  784.      * @throws IllegalArgumentException if the size is too large for a 1D array.
  785.      */
  786.     static int checkArraySize(long size) {
  787.         if (size > MAX_ARRAY_SIZE) {
  788.             throw new IllegalArgumentException("Unable to allocate array of size: " + size);
  789.         }
  790.         return (int) size;
  791.     }
  792. }