001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import java.math.BigInteger;
020import java.util.Arrays;
021import org.apache.commons.rng.UniformRandomProvider;
022
023/**
024 * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
025 * sample from {@code n} values each with an associated relative weight. If all unique items
026 * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
027 *
028 * <p>Given a list {@code L} of {@code n} positive numbers,
029 * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
030 * integer {@code i} with relative probability {@code L[i]}.
031 *
032 * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
033 * <ul>
034 *   <li>For integer weights, the probability of returning {@code i} is precisely equal to the
035 *   rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
 *   <li>For floating-point weights, each weight {@code L[i]} is converted to the
037 *   corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
038 *   {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
039 * </ul>
040 *
041 * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
042 * ignores very small relative weights may have improved sampling performance.
043 *
044 * <p>This implementation is based on the algorithm in:
045 *
046 * <blockquote>
047 *  Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
048 *  The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
049 *  Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
050 *  Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
051 *  Palermo, Sicily, Italy, 2020.
052 * </blockquote>
053 *
054 * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
055 *
056 * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
057 * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
058 * PMLR 108:1036-1046.</a>
059 * @since 1.5
060 */
061public abstract class FastLoadedDiceRollerDiscreteSampler
062    implements SharedStateDiscreteSampler {
063    /**
064     * The maximum size of an array.
065     *
066     * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
067     * It allows VMs to reserve some header words in an array.
068     */
069    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
070    /** The maximum biased exponent for a finite double.
071     * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */
072    private static final int MAX_BIASED_EXPONENT = 2046;
073    /** Size of the mantissa of a double. Equal to 52 bits. */
074    private static final int MANTISSA_SIZE = 52;
075    /** Mask to extract the 52-bit mantissa from a long representation of a double. */
076    private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
077    /** BigInteger representation of {@link Long#MAX_VALUE}. */
078    private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
079    /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
080     * The value will remain positive for any shift {@code <=} this value. */
081    private static final int MAX_OFFSET = 10;
082    /** Initial value for no leaf node label. */
083    private static final int NO_LABEL = Integer.MAX_VALUE;
084    /** Name of the sampler. */
085    private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";
086
087    /**
088     * Class to handle the edge case of observations in only one category.
089     */
090    private static final class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
091        /** The sample value. */
092        private final int sampleValue;
093
094        /**
095         * @param sampleValue Sample value.
096         */
097        FixedValueDiscreteSampler(int sampleValue) {
098            this.sampleValue = sampleValue;
099        }
100
101        @Override
102        public int sample() {
103            return sampleValue;
104        }
105
106        @Override
107        public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
108            return this;
109        }
110
111        @Override
112        public String toString() {
113            return SAMPLER_NAME;
114        }
115    }
116
117    /**
118     * Class to implement the FLDR sample algorithm.
119     */
120    private static final class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
121        /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
122         * the boolean source. */
123        private static final int EMPTY_BOOL_SOURCE = 1;
124
125        /** Underlying source of randomness. */
126        private final UniformRandomProvider rng;
127        /** Number of categories. */
128        private final int n;
129        /** Number of levels in the discrete distribution generating (DDG) tree.
130         * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
131        private final int k;
132        /** Number of leaf nodes at each level. */
133        private final int[] h;
134        /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
135        private final int[] lH;
136
137        /**
138         * Provides a bit source for booleans.
139         *
140         * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
141         *
142         * <p>Only stores 31-bits when full as 1 bit has already been consumed.
143         * The sign bit is a flag that shifts down so the source eventually equals 1
144         * when all bits are consumed and will trigger a refill.
145         */
146        private int booleanSource = EMPTY_BOOL_SOURCE;
147
148        /**
149         * Creates a sampler.
150         *
151         * <p>The input parameters are not validated and must be correctly computed tables.
152         *
153         * @param rng Generator of uniformly distributed random numbers.
154         * @param n Number of categories
155         * @param k Number of levels in the discrete distribution generating (DDG) tree.
156         * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
157         * @param h Number of leaf nodes at each level.
158         * @param lH Stores the leaf node labels in increasing order.
159         */
160        FLDRSampler(UniformRandomProvider rng,
161                    int n,
162                    int k,
163                    int[] h,
164                    int[] lH) {
165            this.rng = rng;
166            this.n = n;
167            this.k = k;
168            // Deliberate direct storage of input arrays
169            this.h = h;
170            this.lH = lH;
171        }
172
173        /**
174         * Creates a copy with a new source of randomness.
175         *
176         * @param rng Generator of uniformly distributed random numbers.
177         * @param source Source to copy.
178         */
179        private FLDRSampler(UniformRandomProvider rng,
180                            FLDRSampler source) {
181            this.rng = rng;
182            this.n = source.n;
183            this.k = source.k;
184            this.h = source.h;
185            this.lH = source.lH;
186        }
187
188        /** {@inheritDoc} */
189        @Override
190        public int sample() {
191            // ALGORITHM 5: SAMPLE
192            int c = 0;
193            int d = 0;
194            for (;;) {
195                // b = flip()
196                // d = 2 * d + (1 - b)
197                d = (d << 1) + flip();
198                if (d < h[c]) {
199                    // z = H[d][c]
200                    final int z = lH[d * k + c];
201                    // assert z != NO_LABEL
202                    if (z < n) {
203                        return z;
204                    }
205                    d = 0;
206                    c = 0;
207                } else {
208                    d = d - h[c];
209                    c++;
210                }
211            }
212        }
213
214        /**
215         * Provides a source of boolean bits.
216         *
217         * <p>Note: This replicates the boolean cache functionality of
218         * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
219         * an {@code int} value rather than a {@code boolean}.
220         *
221         * @return the bit (0 or 1)
222         */
223        private int flip() {
224            int bits = booleanSource;
225            if (bits == 1) {
226                // Refill
227                bits = rng.nextInt();
228                // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
229                booleanSource = Integer.MIN_VALUE | (bits >>> 1);
230                return bits & 0x1;
231            }
232            // Shift down eventually triggering refill, return current lowest bit
233            booleanSource = bits >>> 1;
234            return bits & 0x1;
235        }
236
237        /** {@inheritDoc} */
238        @Override
239        public String toString() {
240            return SAMPLER_NAME + " [" + rng.toString() + "]";
241        }
242
243        /** {@inheritDoc} */
244        @Override
245        public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
246            return new FLDRSampler(rng, this);
247        }
248    }
249
250    /** Package-private constructor. */
251    FastLoadedDiceRollerDiscreteSampler() {
252        // Intentionally empty
253    }
254
255    /** {@inheritDoc} */
256    // Redeclare the signature to return a FastLoadedDiceRollerSampler not a SharedStateLongSampler
257    @Override
258    public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);
259
260    /**
261     * Creates a sampler.
262     *
263     * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
264     * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
265     * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
266     * as a single array.
267     *
268     * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
269     * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
270     * when the sum of frequencies is large enough to create k=63.
271     *
272     * @param rng Generator of uniformly distributed random numbers.
273     * @param frequencies Observed frequencies of the discrete distribution.
274     * @return the sampler
275     * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
276     * frequency is negative, the sum of all frequencies is either zero or
277     * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
278     * is too large.
279     */
280    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
281                                                         long[] frequencies) {
282        final long m = sum(frequencies);
283
284        // Obtain indices of non-zero frequencies
285        final int[] indices = indicesOfNonZero(frequencies);
286
287        // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
288        // (as log2(m) == 0 will break the computation of the DDG tree).
289        if (indices.length == 1) {
290            return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
291        }
292
293        return createSampler(rng, frequencies, indices, m);
294    }
295
296    /**
297     * Creates a sampler.
298     *
299     * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
300     * The numerators {@code p} are scaled to use a common denominator before summing.
301     *
302     * <p>All weights are used to create the sampler. Weights with a small magnitude relative
303     * to the largest weight can be excluded using the constructor method with the
304     * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
305     *
306     * @param rng Generator of uniformly distributed random numbers.
307     * @param weights Weights of the discrete distribution.
308     * @return the sampler
309     * @throws IllegalArgumentException if {@code weights} is null or empty, a
310     * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
311     * of the discrete distribution generating tree is too large.
312     * @see #of(UniformRandomProvider, double[], int)
313     */
314    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
315                                                         double[] weights) {
316        return of(rng, weights, 0);
317    }
318
319    /**
320     * Creates a sampler.
321     *
322     * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
323     * a power of 2. The numerators {@code p} are scaled to use a common
324     * denominator before summing.
325     *
326     * <p>Note: The discrete distribution generating (DDG) tree requires
327     * {@code (n + 1) * k} entries where {@code n} is the number of categories,
328     * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators
329     * {@code q}. An exception is raised if this cannot be allocated as a single
330     * array.
331     *
332     * <p>For reference the value {@code k} is equal to or greater than the ratio of
333     * the largest to the smallest weight expressed as a power of 2. For
334     * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
335     * {@code k} increases with the sum of the weight numerators. A number of
336     * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
337     * would be required to raise an exception when the minimum weight is
338     * {@link Double#MIN_VALUE}.
339     *
340     * <p>Weights with a small magnitude relative to the largest weight can be
341     * excluded using the relative magnitude parameter {@code alpha}. This will set
342     * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
343     * <em>smaller</em> than the largest weight. This comparison is made using only
344     * the exponent of the input weights. The {@code alpha} parameter is ignored if
345     * not above zero. Note that a small {@code alpha} parameter will exclude more
346     * weights than a large {@code alpha} parameter.
347     *
348     * <p>The alpha parameter can be used to exclude categories that
349     * have a very low probability of occurrence and will improve the construction
350     * performance of the sampler. The effect on sampling performance depends on
351     * the relative weights of the excluded categories; typically a high {@code alpha}
352     * is used to exclude categories that would be visited with a very low probability
353     * and the sampling performance is unchanged.
354     *
355     * <p><b>Implementation Note</b>
356     *
357     * <p>This method creates a sampler with <em>exact</em> samples from the
358     * specified probability distribution. It is recommended to use this method:
359     * <ul>
360     *  <li>if the weights are computed, for example from a probability mass function; or
361     *  <li>if the weights sum to an infinite value.
362     * </ul>
363     *
364     * <p>If the weights are computed from empirical observations then it is
365     * recommended to use the factory method
366     * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
367     * requires the total number of observations to be representable as a long
368     * integer.
369     *
370     * <p>Note that if all weights are scaled by a power of 2 to be integers, and
371     * each integer can be represented as a positive 64-bit long value, then the
372     * sampler created using this method will match the output from a sampler
373     * created with the scaled weights converted to long values for the factory
374     * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
375     * assumes the sum of the integer values does not overflow.
376     *
377     * <p>It should be noted that the conversion of weights to rational numbers has
378     * a performance overhead during construction (sampling performance is not
379     * affected). This may be avoided by first converting them to integer values
380     * that can be summed without overflow. For example by scaling values by
381     * {@code 2^62 / sum} and converting to long by casting or rounding.
382     *
383     * <p>This approach may increase the efficiency of construction. The resulting
384     * sampler may no longer produce <em>exact</em> samples from the distribution.
385     * In particular any weights with a converted frequency of zero cannot be
386     * sampled.
387     *
388     * @param rng Generator of uniformly distributed random numbers.
389     * @param weights Weights of the discrete distribution.
390     * @param alpha Alpha parameter.
391     * @return the sampler
392     * @throws IllegalArgumentException if {@code weights} is null or empty, a
393     * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
394     * or the size of the discrete distribution generating tree is too large.
395     * @see #of(UniformRandomProvider, long[])
396     */
397    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
398                                                         double[] weights,
399                                                         int alpha) {
400        final int n = checkWeightsNonZeroLength(weights);
401
402        // Convert floating-point double to a relative weight
403        // using a shifted integer representation
404        final long[] frequencies = new long[n];
405        final int[] offsets = new int[n];
406        convertToIntegers(weights, frequencies, offsets, alpha);
407
408        // Obtain indices of non-zero weights
409        final int[] indices = indicesOfNonZero(frequencies);
410
411        // Edge case for 1 non-zero weight.
412        if (indices.length == 1) {
413            return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
414        }
415
416        final BigInteger m = sum(frequencies, offsets, indices);
417
418        // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
419        if (m.compareTo(MAX_LONG) <= 0) {
420            // Apply the offset
421            for (int i = 0; i < n; i++) {
422                frequencies[i] <<= offsets[i];
423            }
424            return createSampler(rng, frequencies, indices, m.longValue());
425        }
426
427        return createSampler(rng, frequencies, offsets, indices, m);
428    }
429
430    /**
431     * Sum the frequencies.
432     *
433     * @param frequencies Frequencies.
434     * @return the sum
435     * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
436     * frequency is negative, or the sum of all frequencies is either zero or above
437     * {@link Long#MAX_VALUE}
438     */
439    private static long sum(long[] frequencies) {
440        // Validate
441        if (frequencies == null || frequencies.length == 0) {
442            throw new IllegalArgumentException("frequencies must contain at least 1 value");
443        }
444
445        // Sum the values.
446        // Combine all the sign bits in the observations and the intermediate sum in a flag.
447        long m = 0;
448        long signFlag = 0;
449        for (final long o : frequencies) {
450            m += o;
451            signFlag |= o | m;
452        }
453
454        // Check for a sign-bit.
455        if (signFlag < 0) {
456            // One or more observations were negative, or the sum overflowed.
457            for (final long o : frequencies) {
458                if (o < 0) {
459                    throw new IllegalArgumentException("frequencies must contain positive values: " + o);
460                }
461            }
462            throw new IllegalArgumentException("Overflow when summing frequencies");
463        }
464        if (m == 0) {
465            throw new IllegalArgumentException("Sum of frequencies is zero");
466        }
467        return m;
468    }
469
470    /**
471     * Convert the floating-point weights to relative weights represented as
472     * integers {@code value * 2^exponent}. The relative weight as an integer is:
473     *
474     * <pre>
475     * BigInteger.valueOf(value).shiftLeft(exponent)
476     * </pre>
477     *
478     * <p>Note that the weights are created using a common power-of-2 scaling
479     * operation so the minimum exponent is zero.
480     *
481     * <p>A positive {@code alpha} parameter is used to set any weight to zero if
482     * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
483     * largest weight. This comparison is made using only the exponent of the input
484     * weights.
485     *
486     * @param weights Weights of the discrete distribution.
487     * @param values Output floating-point mantissas converted to 53-bit integers.
488     * @param exponents Output power of 2 exponent.
489     * @param alpha Alpha parameter.
490     * @throws IllegalArgumentException if a weight is negative, infinite or
491     * {@code NaN}, or the sum of all weights is zero.
492     */
493    private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
494        int maxExponent = Integer.MIN_VALUE;
495        for (int i = 0; i < weights.length; i++) {
496            final double weight = weights[i];
497            // Ignore zero.
498            // When creating the integer value later using bit shifts the result will remain zero.
499            if (weight == 0) {
500                continue;
501            }
502            final long bits = Double.doubleToRawLongBits(weight);
503
504            // For the IEEE 754 format see Double.longBitsToDouble(long).
505
506            // Extract the exponent (with the sign bit)
507            int exp = (int) (bits >>> MANTISSA_SIZE);
508            // Detect negative, infinite or NaN.
509            // Note: Negative values sign bit will cause the exponent to be too high.
510            if (exp > MAX_BIASED_EXPONENT) {
511                throw new IllegalArgumentException("Invalid weight: " + weight);
512            }
513            long mantissa;
514            if (exp == 0) {
515                // Sub-normal number:
516                mantissa = (bits & MANTISSA_MASK) << 1;
517                // Here we convert to a normalised number by counting the leading zeros
518                // to obtain the number of shifts of the most significant bit in
519                // the mantissa that is required to get a 1 at position 53 (i.e. as
520                // if it were a normal number with assumed leading bit).
521                final int shift = Long.numberOfLeadingZeros(mantissa << 11);
522                mantissa <<= shift;
523                exp -= shift;
524            } else {
525                // Normal number. Add the implicit leading 1-bit.
526                mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
527            }
528
529            // Here the floating-point number is equal to:
530            // mantissa * 2^(exp-1075)
531
532            values[i] = mantissa;
533            exponents[i] = exp;
534            maxExponent = Math.max(maxExponent, exp);
535        }
536
537        // No exponent indicates that all weights are zero
538        if (maxExponent == Integer.MIN_VALUE) {
539            throw new IllegalArgumentException("Sum of weights is zero");
540        }
541
542        filterWeights(values, exponents, alpha, maxExponent);
543        scaleWeights(values, exponents);
544    }
545
546    /**
547     * Filters small weights using the {@code alpha} parameter.
548     * A positive {@code alpha} parameter is used to set any weight to zero if
549     * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
550     * largest weight. This comparison is made using only the exponent of the input
551     * weights.
552     *
553     * @param values 53-bit values.
554     * @param exponents Power of 2 exponent.
555     * @param alpha Alpha parameter.
556     * @param maxExponent Maximum exponent.
557     */
558    private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
559        if (alpha > 0) {
560            // Filter weights. This must be done before the values are shifted so
561            // the exponent represents the approximate magnitude of the value.
562            for (int i = 0; i < exponents.length; i++) {
563                if (maxExponent - exponents[i] > alpha) {
564                    values[i] = 0;
565                }
566            }
567        }
568    }
569
570    /**
571     * Scale the weights represented as integers {@code value * 2^exponent} to use a
572     * minimum exponent of zero. The values are scaled to remove any common trailing zeros
573     * in their representation. This ultimately reduces the size of the discrete distribution
574     * generating (DGG) tree.
575     *
576     * @param values 53-bit values.
577     * @param exponents Power of 2 exponent.
578     */
579    private static void scaleWeights(long[] values, int[] exponents) {
580        // Find the minimum exponent and common trailing zeros.
581        int minExponent = Integer.MAX_VALUE;
582        for (int i = 0; i < exponents.length; i++) {
583            if (values[i] != 0) {
584                minExponent = Math.min(minExponent, exponents[i]);
585            }
586        }
587        // Trailing zeros occur when the original weights have a representation with
588        // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
589        int trailingZeros = Long.SIZE;
590        for (int i = 0; i < values.length && trailingZeros != 0; i++) {
591            trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
592        }
593        // Scale by a power of 2 so the minimum exponent is zero.
594        for (int i = 0; i < exponents.length; i++) {
595            exponents[i] -= minExponent;
596        }
597        // Remove common trailing zeros.
598        if (trailingZeros != 0) {
599            for (int i = 0; i < values.length; i++) {
600                values[i] >>>= trailingZeros;
601            }
602        }
603    }
604
605    /**
606     * Sum the integers at the specified indices.
607     * Integers are represented as {@code value * 2^exponent}.
608     *
609     * @param values 53-bit values.
610     * @param exponents Power of 2 exponent.
611     * @param indices Indices to sum.
612     * @return the sum
613     */
614    private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
615        BigInteger m = BigInteger.ZERO;
616        for (final int i : indices) {
617            m = m.add(toBigInteger(values[i], exponents[i]));
618        }
619        return m;
620    }
621
622    /**
623     * Convert the value and left shift offset to a BigInteger.
624     * It is assumed the value is at most 53-bits. This allows optimising the left
625     * shift if it is below 11 bits.
626     *
627     * @param value 53-bit value.
628     * @param offset Left shift offset (must be positive).
629     * @return the BigInteger
630     */
631    private static BigInteger toBigInteger(long value, int offset) {
632        // Ignore zeros. The sum method uses indices of non-zero values.
633        if (offset <= MAX_OFFSET) {
634            // Assume (value << offset) <= Long.MAX_VALUE
635            return BigInteger.valueOf(value << offset);
636        }
637        return BigInteger.valueOf(value).shiftLeft(offset);
638    }
639
640    /**
641     * Creates the sampler.
642     *
643     * <p>It is assumed the frequencies are all positive and the sum does not
644     * overflow.
645     *
646     * @param rng Generator of uniformly distributed random numbers.
647     * @param frequencies Observed frequencies of the discrete distribution.
648     * @param indices Indices of non-zero frequencies.
649     * @param m Sum of the frequencies.
650     * @return the sampler
651     */
652    private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
653                                                                     long[] frequencies,
654                                                                     int[] indices,
655                                                                     long m) {
656        // ALGORITHM 5: PREPROCESS
657        // a == frequencies
658        // m = sum(a)
659        // h = leaf node count
660        // H = leaf node label (lH)
661
662        final int n = frequencies.length;
663
664        // k = ceil(log2(m))
665        final int k = 64 - Long.numberOfLeadingZeros(m - 1);
666        // r = a(n+1) = 2^k - m
667        final long r = (1L << k) - m;
668
669        // Note:
670        // A sparse matrix can often be used for H, as most of its entries are empty.
671        // This implementation uses a 1D array for efficiency at the cost of memory.
672        // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
673        // observations is large enough to create k=63.
674        // This could be handled using a 2D array. In practice a number of categories this
675        // large is not expected and is currently not supported.
676        final int[] h = new int[k];
677        final int[] lH = new int[checkArraySize((n + 1L) * k)];
678        Arrays.fill(lH, NO_LABEL);
679
680        int d;
681        for (int j = 0; j < k; j++) {
682            final int shift = (k - 1) - j;
683            final long bitMask = 1L << shift;
684
685            d = 0;
686            for (final int i : indices) {
687                // bool w ← (a[i] >> (k − 1) − j)) & 1
688                // h[j] = h[j] + w
689                // if w then:
690                if ((frequencies[i] & bitMask) != 0) {
691                    h[j]++;
692                    // H[d][j] = i
693                    lH[d * k + j] = i;
694                    d++;
695                }
696            }
697            // process a(n+1) without extending the input frequencies array by 1
698            if ((r & bitMask) != 0) {
699                h[j]++;
700                lH[d * k + j] = n;
701            }
702        }
703
704        return new FLDRSampler(rng, n, k, h, lH);
705    }
706
707    /**
708     * Creates the sampler. Frequencies are represented as a 53-bit value with a
709     * left-shift offset.
710     * <pre>
711     * BigInteger.valueOf(value).shiftLeft(offset)
712     * </pre>
713     *
714     * <p>It is assumed the frequencies are all positive.
715     *
716     * @param rng Generator of uniformly distributed random numbers.
717     * @param frequencies Observed frequencies of the discrete distribution.
718     * @param offsets Left shift offsets (must be positive).
719     * @param indices Indices of non-zero frequencies.
720     * @param m Sum of the frequencies.
721     * @return the sampler
722     */
723    private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
724                                                                     long[] frequencies,
725                                                                     int[] offsets,
726                                                                     int[] indices,
727                                                                     BigInteger m) {
728        // Repeat the logic from createSampler(...) using extended arithmetic to test the bits
729
730        // ALGORITHM 5: PREPROCESS
731        // a == frequencies
732        // m = sum(a)
733        // h = leaf node count
734        // H = leaf node label (lH)
735
736        final int n = frequencies.length;
737
738        // k = ceil(log2(m))
739        final int k = m.subtract(BigInteger.ONE).bitLength();
740        // r = a(n+1) = 2^k - m
741        final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);
742
743        final int[] h = new int[k];
744        final int[] lH = new int[checkArraySize((n + 1L) * k)];
745        Arrays.fill(lH, NO_LABEL);
746
747        int d;
748        for (int j = 0; j < k; j++) {
749            final int shift = (k - 1) - j;
750
751            d = 0;
752            for (final int i : indices) {
753                // bool w ← (a[i] >> (k − 1) − j)) & 1
754                // h[j] = h[j] + w
755                // if w then:
756                if (testBit(frequencies[i], offsets[i], shift)) {
757                    h[j]++;
758                    // H[d][j] = i
759                    lH[d * k + j] = i;
760                    d++;
761                }
762            }
763            // process a(n+1) without extending the input frequencies array by 1
764            if (r.testBit(shift)) {
765                h[j]++;
766                lH[d * k + j] = n;
767            }
768        }
769
770        return new FLDRSampler(rng, n, k, h, lH);
771    }
772
773    /**
774     * Test the logical bit of the shifted integer representation.
775     * The value is assumed to have at most 53-bits of information. The offset
776     * is assumed to be positive. This is functionally equivalent to:
777     * <pre>
778     * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
779     * </pre>
780     *
781     * @param value 53-bit value.
782     * @param offset Left shift offset.
783     * @param n Index of bit to test.
784     * @return true if the bit is 1
785     */
786    private static boolean testBit(long value, int offset, int n) {
787        if (n < offset) {
788            // All logical trailing bits are zero
789            return false;
790        }
791        // Test if outside the 53-bit value (note that the implicit 1 bit
792        // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
793        final int bit = n - offset;
794        return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
795    }
796
797    /**
798     * Check the weights have a non-zero length.
799     *
800     * @param weights Weights.
801     * @return the length
802     */
803    private static int checkWeightsNonZeroLength(double[] weights) {
804        if (weights == null || weights.length == 0) {
805            throw new IllegalArgumentException("weights must contain at least 1 value");
806        }
807        return weights.length;
808    }
809
810    /**
811     * Create the indices of non-zero values.
812     *
813     * @param values Values.
814     * @return the indices
815     */
816    private static int[] indicesOfNonZero(long[] values) {
817        int n = 0;
818        final int[] indices = new int[values.length];
819        for (int i = 0; i < values.length; i++) {
820            if (values[i] != 0) {
821                indices[n++] = i;
822            }
823        }
824        return Arrays.copyOf(indices, n);
825    }
826
827    /**
828     * Find the index of the first non-zero frequency.
829     *
830     * @param frequencies Frequencies.
831     * @return the index
832     * @throws IllegalStateException if all frequencies are zero.
833     */
834    static int indexOfNonZero(long[] frequencies) {
835        for (int i = 0; i < frequencies.length; i++) {
836            if (frequencies[i] != 0) {
837                return i;
838            }
839        }
840        throw new IllegalStateException("All frequencies are zero");
841    }
842
843    /**
844     * Check the size is valid for a 1D array.
845     *
846     * @param size Size
847     * @return the size as an {@code int}
848     * @throws IllegalArgumentException if the size is too large for a 1D array.
849     */
850    static int checkArraySize(long size) {
851        if (size > MAX_ARRAY_SIZE) {
852            throw new IllegalArgumentException("Unable to allocate array of size: " + size);
853        }
854        return (int) size;
855    }
856}