View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.math.BigInteger;
20  import java.util.Arrays;
21  import org.apache.commons.rng.UniformRandomProvider;
22  
23  /**
24   * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
25   * sample from {@code n} values each with an associated relative weight. If all unique items
26   * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
27   *
28   * <p>Given a list {@code L} of {@code n} positive numbers,
29   * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
30   * integer {@code i} with relative probability {@code L[i]}.
31   *
32   * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
33   * <ul>
34   *   <li>For integer weights, the probability of returning {@code i} is precisely equal to the
35   *   rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
36   *   <li>For floating-point weights, each weight {@code L[i]} is converted to the
37   *   corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
38   *   {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
39   * </ul>
40   *
41   * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
42   * ignores very small relative weights may have improved sampling performance.
43   *
44   * <p>This implementation is based on the algorithm in:
45   *
46   * <blockquote>
47   *  Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
48   *  The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
49   *  Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
50   *  Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
51   *  Palermo, Sicily, Italy, 2020.
52   * </blockquote>
53   *
54   * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
55   *
56   * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
57   * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
58   * PMLR 108:1036-1046.</a>
59   * @since 1.5
60   */
61  public abstract class FastLoadedDiceRollerDiscreteSampler
62      implements SharedStateDiscreteSampler {
/**
 * The maximum size of an array.
 *
 * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
 * It allows VMs to reserve some header words in an array.
 */
private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
/** The maximum biased exponent for a finite double.
 * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}.
 * When the exponent bits are extracted together with the sign bit, any value above
 * this limit identifies a negative, infinite or {@code NaN} number. */
private static final int MAX_BIASED_EXPONENT = 2046;
/** Size of the mantissa of a double. Equal to 52 bits. */
private static final int MANTISSA_SIZE = 52;
/** Mask to extract the 52-bit mantissa from a long representation of a double. */
private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
/** BigInteger representation of {@link Long#MAX_VALUE}. */
private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
/** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
 * The value will remain positive for any shift {@code <=} this value
 * (53 + 10 = 63 bits, leaving the sign bit of a {@code long} clear). */
private static final int MAX_OFFSET = 10;
/** Initial value for no leaf node label. Used to pre-fill the table of leaf node labels. */
private static final int NO_LABEL = Integer.MAX_VALUE;
/** Name of the sampler. Used as the prefix of the {@code toString()} representation. */
private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";
86  
87      /**
88       * Class to handle the edge case of observations in only one category.
89       */
90      private static class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
91          /** The sample value. */
92          private final int sampleValue;
93  
94          /**
95           * @param sampleValue Sample value.
96           */
97          FixedValueDiscreteSampler(int sampleValue) {
98              this.sampleValue = sampleValue;
99          }
100 
101         @Override
102         public int sample() {
103             return sampleValue;
104         }
105 
106         @Override
107         public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
108             return this;
109         }
110 
111         @Override
112         public String toString() {
113             return SAMPLER_NAME;
114         }
115     }
116 
117     /**
118      * Class to implement the FLDR sample algorithm.
119      */
120     private static class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
121         /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
122          * the boolean source. */
123         private static final int EMPTY_BOOL_SOURCE = 1;
124 
125         /** Underlying source of randomness. */
126         private final UniformRandomProvider rng;
127         /** Number of categories. */
128         private final int n;
129         /** Number of levels in the discrete distribution generating (DDG) tree.
130          * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
131         private final int k;
132         /** Number of leaf nodes at each level. */
133         private final int[] h;
134         /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
135         private final int[] lH;
136 
137         /**
138          * Provides a bit source for booleans.
139          *
140          * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
141          *
142          * <p>Only stores 31-bits when full as 1 bit has already been consumed.
143          * The sign bit is a flag that shifts down so the source eventually equals 1
144          * when all bits are consumed and will trigger a refill.
145          */
146         private int booleanSource = EMPTY_BOOL_SOURCE;
147 
148         /**
149          * Creates a sampler.
150          *
151          * <p>The input parameters are not validated and must be correctly computed tables.
152          *
153          * @param rng Generator of uniformly distributed random numbers.
154          * @param n Number of categories
155          * @param k Number of levels in the discrete distribution generating (DDG) tree.
156          * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
157          * @param h Number of leaf nodes at each level.
158          * @param lH Stores the leaf node labels in increasing order.
159          */
160         FLDRSampler(UniformRandomProvider rng,
161                     int n,
162                     int k,
163                     int[] h,
164                     int[] lH) {
165             this.rng = rng;
166             this.n = n;
167             this.k = k;
168             // Deliberate direct storage of input arrays
169             this.h = h;
170             this.lH = lH;
171         }
172 
173         /**
174          * Creates a copy with a new source of randomness.
175          *
176          * @param rng Generator of uniformly distributed random numbers.
177          * @param source Source to copy.
178          */
179         private FLDRSampler(UniformRandomProvider rng,
180                             FLDRSampler source) {
181             this.rng = rng;
182             this.n = source.n;
183             this.k = source.k;
184             this.h = source.h;
185             this.lH = source.lH;
186         }
187 
188         /** {@inheritDoc} */
189         @Override
190         public int sample() {
191             // ALGORITHM 5: SAMPLE
192             int c = 0;
193             int d = 0;
194             for (;;) {
195                 // b = flip()
196                 // d = 2 * d + (1 - b)
197                 d = (d << 1) + flip();
198                 if (d < h[c]) {
199                     // z = H[d][c]
200                     final int z = lH[d * k + c];
201                     // assert z != NO_LABEL
202                     if (z < n) {
203                         return z;
204                     }
205                     d = 0;
206                     c = 0;
207                 } else {
208                     d = d - h[c];
209                     c++;
210                 }
211             }
212         }
213 
214         /**
215          * Provides a source of boolean bits.
216          *
217          * <p>Note: This replicates the boolean cache functionality of
218          * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
219          * an {@code int} value rather than a {@code boolean}.
220          *
221          * @return the bit (0 or 1)
222          */
223         private int flip() {
224             int bits = booleanSource;
225             if (bits == 1) {
226                 // Refill
227                 bits = rng.nextInt();
228                 // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
229                 booleanSource = Integer.MIN_VALUE | (bits >>> 1);
230                 return bits & 0x1;
231             }
232             // Shift down eventually triggering refill, return current lowest bit
233             booleanSource = bits >>> 1;
234             return bits & 0x1;
235         }
236 
237         /** {@inheritDoc} */
238         @Override
239         public String toString() {
240             return SAMPLER_NAME + " [" + rng.toString() + "]";
241         }
242 
243         /** {@inheritDoc} */
244         @Override
245         public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
246             return new FLDRSampler(rng, this);
247         }
248     }
249 
/**
 * Package-private constructor. Instances are obtained from the static factory
 * methods of this class.
 */
FastLoadedDiceRollerDiscreteSampler() {
    // Intentionally empty
}
254 
/** {@inheritDoc} */
// Redeclare the signature to return a FastLoadedDiceRollerDiscreteSampler
// not a SharedStateDiscreteSampler
@Override
public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);
259 
260     /**
261      * Creates a sampler.
262      *
263      * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
264      * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
265      * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
266      * as a single array.
267      *
268      * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
269      * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
270      * when the sum of frequencies is large enough to create k=63.
271      *
272      * @param rng Generator of uniformly distributed random numbers.
273      * @param frequencies Observed frequencies of the discrete distribution.
274      * @return the sampler
275      * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
276      * frequency is negative, the sum of all frequencies is either zero or
277      * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
278      * is too large.
279      */
280     public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
281                                                          long[] frequencies) {
282         final long m = sum(frequencies);
283 
284         // Obtain indices of non-zero frequencies
285         final int[] indices = indicesOfNonZero(frequencies);
286 
287         // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
288         // (as log2(m) == 0 will break the computation of the DDG tree).
289         if (indices.length == 1) {
290             return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
291         }
292 
293         return createSampler(rng, frequencies, indices, m);
294     }
295 
/**
 * Creates a sampler using all the weights. This delegates to
 * {@link #of(UniformRandomProvider, double[], int)} with an {@code alpha} of zero,
 * which is ignored as it is not above zero.
 *
 * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
 * The numerators {@code p} are scaled to use a common denominator before summing.
 *
 * <p>All weights are used to create the sampler. Weights with a small magnitude relative
 * to the largest weight can be excluded using the constructor method with the
 * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
 *
 * @param rng Generator of uniformly distributed random numbers.
 * @param weights Weights of the discrete distribution.
 * @return the sampler
 * @throws IllegalArgumentException if {@code weights} is null or empty, a
 * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
 * of the discrete distribution generating tree is too large.
 * @see #of(UniformRandomProvider, double[], int)
 */
public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                     double[] weights) {
    return of(rng, weights, 0);
}
318 
319     /**
320      * Creates a sampler.
321      *
322      * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
323      * a power of 2. The numerators {@code p} are scaled to use a common
324      * denominator before summing.
325      *
326      * <p>Note: The discrete distribution generating (DDG) tree requires
327      * {@code (n + 1) * k} entries where {@code n} is the number of categories,
328      * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators
329      * {@code q}. An exception is raised if this cannot be allocated as a single
330      * array.
331      *
332      * <p>For reference the value {@code k} is equal to or greater than the ratio of
333      * the largest to the smallest weight expressed as a power of 2. For
334      * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
335      * {@code k} increases with the sum of the weight numerators. A number of
336      * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
337      * would be required to raise an exception when the minimum weight is
338      * {@link Double#MIN_VALUE}.
339      *
340      * <p>Weights with a small magnitude relative to the largest weight can be
341      * excluded using the relative magnitude parameter {@code alpha}. This will set
342      * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
343      * <em>smaller</em> than the largest weight. This comparison is made using only
344      * the exponent of the input weights. The {@code alpha} parameter is ignored if
345      * not above zero. Note that a small {@code alpha} parameter will exclude more
346      * weights than a large {@code alpha} parameter.
347      *
348      * <p>The alpha parameter can be used to exclude categories that
349      * have a very low probability of occurrence and will improve the construction
350      * performance of the sampler. The effect on sampling performance depends on
351      * the relative weights of the excluded categories; typically a high {@code alpha}
352      * is used to exclude categories that would be visited with a very low probability
353      * and the sampling performance is unchanged.
354      *
355      * <p><b>Implementation Note</b>
356      *
357      * <p>This method creates a sampler with <em>exact</em> samples from the
358      * specified probability distribution. It is recommended to use this method:
359      * <ul>
360      *  <li>if the weights are computed, for example from a probability mass function; or
361      *  <li>if the weights sum to an infinite value.
362      * </ul>
363      *
364      * <p>If the weights are computed from empirical observations then it is
365      * recommended to use the factory method
366      * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
367      * requires the total number of observations to be representable as a long
368      * integer.
369      *
370      * <p>Note that if all weights are scaled by a power of 2 to be integers, and
371      * each integer can be represented as a positive 64-bit long value, then the
372      * sampler created using this method will match the output from a sampler
373      * created with the scaled weights converted to long values for the factory
374      * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
375      * assumes the sum of the integer values does not overflow.
376      *
377      * <p>It should be noted that the conversion of weights to rational numbers has
378      * a performance overhead during construction (sampling performance is not
379      * affected). This may be avoided by first converting them to integer values
380      * that can be summed without overflow. For example by scaling values by
381      * {@code 2^62 / sum} and converting to long by casting or rounding.
382      *
383      * <p>This approach may increase the efficiency of construction. The resulting
384      * sampler may no longer produce <em>exact</em> samples from the distribution.
385      * In particular any weights with a converted frequency of zero cannot be
386      * sampled.
387      *
388      * @param rng Generator of uniformly distributed random numbers.
389      * @param weights Weights of the discrete distribution.
390      * @param alpha Alpha parameter.
391      * @return the sampler
392      * @throws IllegalArgumentException if {@code weights} is null or empty, a
393      * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
394      * or the size of the discrete distribution generating tree is too large.
395      * @see #of(UniformRandomProvider, long[])
396      */
397     public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
398                                                          double[] weights,
399                                                          int alpha) {
400         final int n = checkWeightsNonZeroLength(weights);
401 
402         // Convert floating-point double to a relative weight
403         // using a shifted integer representation
404         final long[] frequencies = new long[n];
405         final int[] offsets = new int[n];
406         convertToIntegers(weights, frequencies, offsets, alpha);
407 
408         // Obtain indices of non-zero weights
409         final int[] indices = indicesOfNonZero(frequencies);
410 
411         // Edge case for 1 non-zero weight.
412         if (indices.length == 1) {
413             return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
414         }
415 
416         final BigInteger m = sum(frequencies, offsets, indices);
417 
418         // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
419         if (m.compareTo(MAX_LONG) <= 0) {
420             // Apply the offset
421             for (int i = 0; i < n; i++) {
422                 frequencies[i] <<= offsets[i];
423             }
424             return createSampler(rng, frequencies, indices, m.longValue());
425         }
426 
427         return createSampler(rng, frequencies, offsets, indices, m);
428     }
429 
430     /**
431      * Sum the frequencies.
432      *
433      * @param frequencies Frequencies.
434      * @return the sum
435      * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
436      * frequency is negative, or the sum of all frequencies is either zero or above
437      * {@link Long#MAX_VALUE}
438      */
439     private static long sum(long[] frequencies) {
440         // Validate
441         if (frequencies == null || frequencies.length == 0) {
442             throw new IllegalArgumentException("frequencies must contain at least 1 value");
443         }
444 
445         // Sum the values.
446         // Combine all the sign bits in the observations and the intermediate sum in a flag.
447         long m = 0;
448         long signFlag = 0;
449         for (final long o : frequencies) {
450             m += o;
451             signFlag |= o | m;
452         }
453 
454         // Check for a sign-bit.
455         if (signFlag < 0) {
456             // One or more observations were negative, or the sum overflowed.
457             for (final long o : frequencies) {
458                 if (o < 0) {
459                     throw new IllegalArgumentException("frequencies must contain positive values: " + o);
460                 }
461             }
462             throw new IllegalArgumentException("Overflow when summing frequencies");
463         }
464         if (m == 0) {
465             throw new IllegalArgumentException("Sum of frequencies is zero");
466         }
467         return m;
468     }
469 
470     /**
471      * Convert the floating-point weights to relative weights represented as
472      * integers {@code value * 2^exponent}. The relative weight as an integer is:
473      *
474      * <pre>
475      * BigInteger.valueOf(value).shiftLeft(exponent)
476      * </pre>
477      *
478      * <p>Note that the weights are created using a common power-of-2 scaling
479      * operation so the minimum exponent is zero.
480      *
481      * <p>A positive {@code alpha} parameter is used to set any weight to zero if
482      * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
483      * largest weight. This comparison is made using only the exponent of the input
484      * weights.
485      *
486      * @param weights Weights of the discrete distribution.
487      * @param values Output floating-point mantissas converted to 53-bit integers.
488      * @param exponents Output power of 2 exponent.
489      * @param alpha Alpha parameter.
490      * @throws IllegalArgumentException if a weight is negative, infinite or
491      * {@code NaN}, or the sum of all weights is zero.
492      */
493     private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
494         int maxExponent = Integer.MIN_VALUE;
495         for (int i = 0; i < weights.length; i++) {
496             final double weight = weights[i];
497             // Ignore zero.
498             // When creating the integer value later using bit shifts the result will remain zero.
499             if (weight == 0) {
500                 continue;
501             }
502             final long bits = Double.doubleToRawLongBits(weight);
503 
504             // For the IEEE 754 format see Double.longBitsToDouble(long).
505 
506             // Extract the exponent (with the sign bit)
507             int exp = (int) (bits >>> MANTISSA_SIZE);
508             // Detect negative, infinite or NaN.
509             // Note: Negative values sign bit will cause the exponent to be too high.
510             if (exp > MAX_BIASED_EXPONENT) {
511                 throw new IllegalArgumentException("Invalid weight: " + weight);
512             }
513             long mantissa;
514             if (exp == 0) {
515                 // Sub-normal number:
516                 mantissa = (bits & MANTISSA_MASK) << 1;
517                 // Here we convert to a normalised number by counting the leading zeros
518                 // to obtain the number of shifts of the most significant bit in
519                 // the mantissa that is required to get a 1 at position 53 (i.e. as
520                 // if it were a normal number with assumed leading bit).
521                 final int shift = Long.numberOfLeadingZeros(mantissa << 11);
522                 mantissa <<= shift;
523                 exp -= shift;
524             } else {
525                 // Normal number. Add the implicit leading 1-bit.
526                 mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
527             }
528 
529             // Here the floating-point number is equal to:
530             // mantissa * 2^(exp-1075)
531 
532             values[i] = mantissa;
533             exponents[i] = exp;
534             maxExponent = Math.max(maxExponent, exp);
535         }
536 
537         // No exponent indicates that all weights are zero
538         if (maxExponent == Integer.MIN_VALUE) {
539             throw new IllegalArgumentException("Sum of weights is zero");
540         }
541 
542         filterWeights(values, exponents, alpha, maxExponent);
543         scaleWeights(values, exponents);
544     }
545 
546     /**
547      * Filters small weights using the {@code alpha} parameter.
548      * A positive {@code alpha} parameter is used to set any weight to zero if
549      * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
550      * largest weight. This comparison is made using only the exponent of the input
551      * weights.
552      *
553      * @param values 53-bit values.
554      * @param exponents Power of 2 exponent.
555      * @param alpha Alpha parameter.
556      * @param maxExponent Maximum exponent.
557      */
558     private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
559         if (alpha > 0) {
560             // Filter weights. This must be done before the values are shifted so
561             // the exponent represents the approximate magnitude of the value.
562             for (int i = 0; i < exponents.length; i++) {
563                 if (maxExponent - exponents[i] > alpha) {
564                     values[i] = 0;
565                 }
566             }
567         }
568     }
569 
570     /**
571      * Scale the weights represented as integers {@code value * 2^exponent} to use a
572      * minimum exponent of zero. The values are scaled to remove any common trailing zeros
573      * in their representation. This ultimately reduces the size of the discrete distribution
574      * generating (DGG) tree.
575      *
576      * @param values 53-bit values.
577      * @param exponents Power of 2 exponent.
578      */
579     private static void scaleWeights(long[] values, int[] exponents) {
580         // Find the minimum exponent and common trailing zeros.
581         int minExponent = Integer.MAX_VALUE;
582         for (int i = 0; i < exponents.length; i++) {
583             if (values[i] != 0) {
584                 minExponent = Math.min(minExponent, exponents[i]);
585             }
586         }
587         // Trailing zeros occur when the original weights have a representation with
588         // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
589         int trailingZeros = Long.SIZE;
590         for (int i = 0; i < values.length && trailingZeros != 0; i++) {
591             trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
592         }
593         // Scale by a power of 2 so the minimum exponent is zero.
594         for (int i = 0; i < exponents.length; i++) {
595             exponents[i] -= minExponent;
596         }
597         // Remove common trailing zeros.
598         if (trailingZeros != 0) {
599             for (int i = 0; i < values.length; i++) {
600                 values[i] >>>= trailingZeros;
601             }
602         }
603     }
604 
605     /**
606      * Sum the integers at the specified indices.
607      * Integers are represented as {@code value * 2^exponent}.
608      *
609      * @param values 53-bit values.
610      * @param exponents Power of 2 exponent.
611      * @param indices Indices to sum.
612      * @return the sum
613      */
614     private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
615         BigInteger m = BigInteger.ZERO;
616         for (final int i : indices) {
617             m = m.add(toBigInteger(values[i], exponents[i]));
618         }
619         return m;
620     }
621 
622     /**
623      * Convert the value and left shift offset to a BigInteger.
624      * It is assumed the value is at most 53-bits. This allows optimising the left
625      * shift if it is below 11 bits.
626      *
627      * @param value 53-bit value.
628      * @param offset Left shift offset (must be positive).
629      * @return the BigInteger
630      */
631     private static BigInteger toBigInteger(long value, int offset) {
632         // Ignore zeros. The sum method uses indices of non-zero values.
633         if (offset <= MAX_OFFSET) {
634             // Assume (value << offset) <= Long.MAX_VALUE
635             return BigInteger.valueOf(value << offset);
636         }
637         return BigInteger.valueOf(value).shiftLeft(offset);
638     }
639 
640     /**
641      * Creates the sampler.
642      *
643      * <p>It is assumed the frequencies are all positive and the sum does not
644      * overflow.
645      *
646      * @param rng Generator of uniformly distributed random numbers.
647      * @param frequencies Observed frequencies of the discrete distribution.
648      * @param indices Indices of non-zero frequencies.
649      * @param m Sum of the frequencies.
650      * @return the sampler
651      */
652     private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
653                                                                      long[] frequencies,
654                                                                      int[] indices,
655                                                                      long m) {
656         // ALGORITHM 5: PREPROCESS
657         // a == frequencies
658         // m = sum(a)
659         // h = leaf node count
660         // H = leaf node label (lH)
661 
662         final int n = frequencies.length;
663 
664         // k = ceil(log2(m))
665         final int k = 64 - Long.numberOfLeadingZeros(m - 1);
666         // r = a(n+1) = 2^k - m
667         final long r = (1L << k) - m;
668 
669         // Note:
670         // A sparse matrix can often be used for H, as most of its entries are empty.
671         // This implementation uses a 1D array for efficiency at the cost of memory.
672         // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
673         // observations is large enough to create k=63.
674         // This could be handled using a 2D array. In practice a number of categories this
675         // large is not expected and is currently not supported.
676         final int[] h = new int[k];
677         final int[] lH = new int[checkArraySize((n + 1L) * k)];
678         Arrays.fill(lH, NO_LABEL);
679 
680         int d;
681         for (int j = 0; j < k; j++) {
682             final int shift = (k - 1) - j;
683             final long bitMask = 1L << shift;
684 
685             d = 0;
686             for (final int i : indices) {
687                 // bool w ← (a[i] >> (k − 1) − j)) & 1
688                 // h[j] = h[j] + w
689                 // if w then:
690                 if ((frequencies[i] & bitMask) != 0) {
691                     h[j]++;
692                     // H[d][j] = i
693                     lH[d * k + j] = i;
694                     d++;
695                 }
696             }
697             // process a(n+1) without extending the input frequencies array by 1
698             if ((r & bitMask) != 0) {
699                 h[j]++;
700                 lH[d * k + j] = n;
701             }
702         }
703 
704         return new FLDRSampler(rng, n, k, h, lH);
705     }
706 
707     /**
708      * Creates the sampler. Frequencies are represented as a 53-bit value with a
709      * left-shift offset.
710      * <pre>
711      * BigInteger.valueOf(value).shiftLeft(offset)
712      * </pre>
713      *
714      * <p>It is assumed the frequencies are all positive.
715      *
716      * @param rng Generator of uniformly distributed random numbers.
717      * @param frequencies Observed frequencies of the discrete distribution.
718      * @param offsets Left shift offsets (must be positive).
719      * @param indices Indices of non-zero frequencies.
720      * @param m Sum of the frequencies.
721      * @return the sampler
722      */
723     private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
724                                                                      long[] frequencies,
725                                                                      int[] offsets,
726                                                                      int[] indices,
727                                                                      BigInteger m) {
728         // Repeat the logic from createSampler(...) using extended arithmetic to test the bits
729 
730         // ALGORITHM 5: PREPROCESS
731         // a == frequencies
732         // m = sum(a)
733         // h = leaf node count
734         // H = leaf node label (lH)
735 
736         final int n = frequencies.length;
737 
738         // k = ceil(log2(m))
739         final int k = m.subtract(BigInteger.ONE).bitLength();
740         // r = a(n+1) = 2^k - m
741         final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);
742 
743         final int[] h = new int[k];
744         final int[] lH = new int[checkArraySize((n + 1L) * k)];
745         Arrays.fill(lH, NO_LABEL);
746 
747         int d;
748         for (int j = 0; j < k; j++) {
749             final int shift = (k - 1) - j;
750 
751             d = 0;
752             for (final int i : indices) {
753                 // bool w ← (a[i] >> (k − 1) − j)) & 1
754                 // h[j] = h[j] + w
755                 // if w then:
756                 if (testBit(frequencies[i], offsets[i], shift)) {
757                     h[j]++;
758                     // H[d][j] = i
759                     lH[d * k + j] = i;
760                     d++;
761                 }
762             }
763             // process a(n+1) without extending the input frequencies array by 1
764             if (r.testBit(shift)) {
765                 h[j]++;
766                 lH[d * k + j] = n;
767             }
768         }
769 
770         return new FLDRSampler(rng, n, k, h, lH);
771     }
772 
773     /**
774      * Test the logical bit of the shifted integer representation.
775      * The value is assumed to have at most 53-bits of information. The offset
776      * is assumed to be positive. This is functionally equivalent to:
777      * <pre>
778      * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
779      * </pre>
780      *
781      * @param value 53-bit value.
782      * @param offset Left shift offset.
783      * @param n Index of bit to test.
784      * @return true if the bit is 1
785      */
786     private static boolean testBit(long value, int offset, int n) {
787         if (n < offset) {
788             // All logical trailing bits are zero
789             return false;
790         }
791         // Test if outside the 53-bit value (note that the implicit 1 bit
792         // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
793         final int bit = n - offset;
794         return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
795     }
796 
797     /**
798      * Check the weights have a non-zero length.
799      *
800      * @param weights Weights.
801      * @return the length
802      */
803     private static int checkWeightsNonZeroLength(double[] weights) {
804         if (weights == null || weights.length == 0) {
805             throw new IllegalArgumentException("weights must contain at least 1 value");
806         }
807         return weights.length;
808     }
809 
810     /**
811      * Create the indices of non-zero values.
812      *
813      * @param values Values.
814      * @return the indices
815      */
816     private static int[] indicesOfNonZero(long[] values) {
817         int n = 0;
818         final int[] indices = new int[values.length];
819         for (int i = 0; i < values.length; i++) {
820             if (values[i] != 0) {
821                 indices[n++] = i;
822             }
823         }
824         return Arrays.copyOf(indices, n);
825     }
826 
827     /**
828      * Find the index of the first non-zero frequency.
829      *
830      * @param frequencies Frequencies.
831      * @return the index
832      * @throws IllegalStateException if all frequencies are zero.
833      */
834     static int indexOfNonZero(long[] frequencies) {
835         for (int i = 0; i < frequencies.length; i++) {
836             if (frequencies[i] != 0) {
837                 return i;
838             }
839         }
840         throw new IllegalStateException("All frequencies are zero");
841     }
842 
843     /**
844      * Check the size is valid for a 1D array.
845      *
846      * @param size Size
847      * @return the size as an {@code int}
848      * @throws IllegalArgumentException if the size is too large for a 1D array.
849      */
850     static int checkArraySize(long size) {
851         if (size > MAX_ARRAY_SIZE) {
852             throw new IllegalArgumentException("Unable to allocate array of size: " + size);
853         }
854         return (int) size;
855     }
856 }