001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.sampling.distribution;
018
019import org.apache.commons.rng.UniformRandomProvider;
020
021import java.util.Arrays;
022
023/**
024 * Distribution sampler that uses the <a
025 * href="https://en.wikipedia.org/wiki/Alias_method">Alias method</a>. It can be used to
026 * sample from {@code n} values each with an associated probability. If all unique items
027 * are assigned the same probability it is more efficient to use the {@link DiscreteUniformSampler}.
028 *
029 * <p>This implementation is based on the detailed explanation of the alias method by
030 * Keith Schartz and implements Vose's algorithm.</p>
031 *
032 * <ul>
033 *  <li>
034 *   <blockquote>
035 *    Vose, M.D.,
036 *    <i>A linear algorithm for generating random numbers with a given distribution,</i>
037 *     IEEE Transactions on Software Engineering, 17, 972-975, 1991.
038 *   </blockquote>
039 *  </li>
040 * </ul>
041 *
042 * <p>The algorithm will sample values in {@code O(1)} time after a pre-processing step of
043 * {@code O(n)} time.</p>
044 *
045 * <p>The alias tables are constructed using fraction probabilities with an assumed denominator
046 * of 2<sup>53</sup>. In the generic case sampling uses {@link UniformRandomProvider#nextInt(int)}
047 * and the upper 53-bits from {@link UniformRandomProvider#nextLong()}.</p>
048 *
049 * <p>Zero padding the input probabilities can be used to make more sampling more efficient.
050 * Any zero entry will always be aliased removing the requirement to compute a {@code long}.
051 * Increased sampling speed comes at the cost of increased storage space. The algorithm requires
052 * approximately 12 bytes of storage per input probability, that is {@code n * 12} for size
053 * {@code n}. Zero-padding only requires 4 bytes of storage per padded value as the probability is
054 * known to be zero. A table can be padded to a power of 2 using the utility function
055 * {@link #of(UniformRandomProvider, double[], int)} to construct the sampler.</p>
056 *
057 * <p>An optimisation is performed for small table sizes that are a power of 2. In this case the
058 * sampling uses 1 or 2 calls from {@link UniformRandomProvider#nextInt()} to generate up to
059 * 64-bits for creation of an 11-bit index and 53-bits for the {@code long}. This optimisation
060 * requires a generator with a high cycle length for the lower order bits.</p>
061 *
062 * <p>Larger table sizes that are a power of 2 will benefit from fast algorithms for
063 * {@link UniformRandomProvider#nextInt(int)} that exploit the power of 2.</p>
064 *
065 * @see <a href="https://en.wikipedia.org/wiki/Alias_method">Alias Method</a>
066 * @see <a href="http://www.keithschwarz.com/darts-dice-coins/">Darts, Dice, and Coins:
067 * Sampling from a Discrete Distribution by Keith Schwartz</a>
068 * @see <a href="https://ieeexplore.ieee.org/document/92917">Vose (1991) IEEE Transactions
069 * on Software Engineering 17, 972-975.</a>
070 * @since 1.3
071 */
072public class AliasMethodDiscreteSampler
073    implements SharedStateDiscreteSampler {
074    /**
075     * The default alpha factor for zero-padding an input probability table. The default
076     * value will pad the probabilities by to the next power-of-2.
077     */
078    private static final int DEFAULT_ALPHA = 0;
079    /** The value zero for a {@code double}. */
080    private static final double ZERO = 0.0;
081    /** The value 1.0 represented as the numerator of a fraction with denominator 2<sup>53</sup>. */
082    private static final long ONE_AS_NUMERATOR = 1L << 53;
083    /**
084     * The multiplier to convert a {@code double} probability in the range {@code [0, 1]}
085     * to the numerator of a fraction with denominator 2<sup>53</sup>.
086     */
087    private static final double CONVERT_TO_NUMERATOR = ONE_AS_NUMERATOR;
088    /**
089     * The maximum size of the small alias table. This is 2<sup>11</sup>.
090     */
091    private static final int MAX_SMALL_POWER_2_SIZE = 1 << 11;
092
093    /** Underlying source of randomness. */
094    protected final UniformRandomProvider rng;
095
096    /**
097     * The probability table. During sampling a random index into this table is selected.
098     * A random probability is compared to the value at this index: if lower then the sample is the
099     * index; if higher then the sample uses the corresponding entry in the alias table.
100     *
101     * <p>This has entries up to the last non-zero element since there is no need to store
102     * probabilities of zero. This is an optimisation for zero-padded input. Any zero value will
103     * always be aliased so any look-up index outside this table always uses the alias.</p>
104     *
105     * <p>Note that a uniform double in the range [0,1) can be generated using 53-bits from a long
106     * to sample all the dyadic rationals with a denominator of 2<sup>53</sup>
107     * (e.g. see org.apache.commons.rng.core.utils.NumberFactory.makeDouble(long)). To avoid
108     * computation of a double and comparison to the probability as a double the probabilities are
109     * stored as 53-bit longs to use integer arithmetic. This is the equivalent of storing the
110     * numerator of a fraction with the denominator of 2<sup>53</sup>.</p>
111     *
112     * <p>During conversion of the probability to a double it is rounded up to the next integer
113     * value. This ensures the functionality of comparing a uniform deviate distributed evenly on
114     * the interval 1/2^53 to the unevenly distributed probability is equivalent, i.e. a uniform
115     * deviate is either below the probability or above it:
116     *
117     * <pre>
118     * Uniform deviate
119     *  1/2^53    2/2^53    3/2^53    4/2^53
120     * --|---------|---------|---------|---
121     *      ^
122     *      |
123     *  probability
124     *             ^
125     *             |
126     *         rounded up
127     * </pre>
128     *
129     * <p>Round-up ensures a non-zero probability is always non-zero and zero probability remains
130     * zero. Thus any item with a non-zero input probability can always be sampled, and a zero
131     * input probability cannot be sampled.</p>
132     *
133     * @see <a href="https://en.wikipedia.org/wiki/Dyadic_rational">Dyadic rational</a>
134     */
135    protected final long[] probability;
136
137    /**
138     * The alias table. During sampling if the random probability is not below the entry in the
139     * probability table then the sample is the alias.
140     */
141    protected final int[] alias;
142
143    /**
144     * Sample from the computed tables exploiting the small power-of-two table size.
145     * This implements a variant of the optimised algorithm as per Vose (1991):
146     *
147     * <pre>
148     * bits = obtained required number of random bits
149     * v = (some of the bits) * constant1
150     * j = (rest of the bits) * constant2
151     * if v &lt; prob[j] then
152     *   return j
153     * else
154     *   return alias[j]
155     * </pre>
156     *
157     * <p>This is a variant because the bits are not multiplied by constants. In the case of
158     * {@code v} the constant is a scale that is pre-applied to the probability table. In the
159     * case of {@code j} the constant is not used to scale a deviate to an index; the index is
160     * from a power-of-2 range and so the bits are used directly.</p>
161     *
162     * <p>This is implemented using up to 64 bits from the random generator.
163     * The index for the table is computed using a mask to extract up to 11 of the lower bits
164     * from an integer. The probability is computed using a second integer combined with the
165     * remaining bits to create 53-bits for the numerator of a fraction with denominator
166     * 2<sup>53</sup>. This is only computed on demand.</p>
167     *
168     * <p>Note: This supports a table size of up to 2^11, or 2048, exclusive. Any larger requires
169     * consuming more than 64-bits and the algorithm is not more efficient than the
170     * {@link AliasMethodDiscreteSampler}.</p>
171     *
172     * <p>Sampling uses 1 or 2 calls to {@link UniformRandomProvider#nextInt()}.</p>
173     */
174    private static final class SmallTableAliasMethodDiscreteSampler extends AliasMethodDiscreteSampler {
175        /** The mask to isolate the lower bits. */
176        private final int mask;
177
178        /**
179         * Create a new instance.
180         *
181         * @param rng Generator of uniformly distributed random numbers.
182         * @param probability Probability table.
183         * @param alias Alias table.
184         */
185        SmallTableAliasMethodDiscreteSampler(final UniformRandomProvider rng,
186                                             final long[] probability,
187                                             final int[] alias) {
188            super(rng, probability, alias);
189            // Assume the table size is a power of 2 and create the mask
190            mask = alias.length - 1;
191        }
192
193        @Override
194        public int sample() {
195            final int bits = rng.nextInt();
196            // Isolate lower bits
197            final int j = bits & mask;
198
199            // Optimisation for zero-padded input tables
200            if (j >= probability.length) {
201                // No probability must use the alias
202                return alias[j];
203            }
204
205            // Create a uniform random deviate as a long.
206            // This replicates functionality from the o.a.c.rng.core.utils.NumberFactory.makeLong
207            final long longBits = (((long) rng.nextInt()) << 32) | (bits & 0xffffffffL);
208
209            // Choose between the two. Use a 53-bit long for the probability.
210            return (longBits >>> 11) < probability[j] ? j : alias[j];
211        }
212
213        /** {@inheritDoc} */
214        @Override
215        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
216            return new SmallTableAliasMethodDiscreteSampler(rng, probability, alias);
217        }
218    }
219
220    /**
221     * Creates a sampler.
222     *
223     * <p>The input parameters are not validated and must be correctly computed alias tables.</p>
224     *
225     * @param rng Generator of uniformly distributed random numbers.
226     * @param probability Probability table.
227     * @param alias Alias table.
228     */
229    AliasMethodDiscreteSampler(final UniformRandomProvider rng,
230                               final long[] probability,
231                               final int[] alias) {
232        this.rng = rng;
233        // Deliberate direct storage of input arrays
234        this.probability = probability;
235        this.alias = alias;
236    }
237
238    /** {@inheritDoc} */
239    @Override
240    public int sample() {
241        // This implements the algorithm as per Vose (1991):
242        // v = uniform()  in [0, 1)
243        // j = uniform(n) in [0, n)
244        // if v < prob[j] then
245        //   return j
246        // else
247        //   return alias[j]
248
249        final int j = rng.nextInt(alias.length);
250
251        // Optimisation for zero-padded input tables
252        if (j >= probability.length) {
253            // No probability must use the alias
254            return alias[j];
255        }
256
257        // Note: We could check the probability before computing a deviate.
258        // p(j) == 0  => alias[j]
259        // p(j) == 1  => j
260        // However it is assumed these edge cases are rare:
261        //
262        // The probability table will be 1 for approximately 1/n samples, i.e. only the
263        // last unpaired probability. This is only worth checking for when the table size (n)
264        // is small. But in that case the user should zero-pad the table for performance.
265        //
266        // The probability table will be 0 when an input probability was zero. We
267        // will assume this is also rare if modelling a discrete distribution where
268        // all samples are possible. The edge case for zero-padded tables is handled above.
269
270        // Choose between the two. Use a 53-bit long for the probability.
271        return (rng.nextLong() >>> 11) < probability[j] ? j : alias[j];
272    }
273
274    /** {@inheritDoc} */
275    @Override
276    public String toString() {
277        return "Alias method [" + rng.toString() + "]";
278    }
279
280    /** {@inheritDoc} */
281    @Override
282    public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
283        return new AliasMethodDiscreteSampler(rng, probability, alias);
284    }
285
286    /**
287     * Creates a sampler.
288     *
289     * <p>The probabilities will be normalised using their sum. The only requirement
290     * is the sum is strictly positive.</p>
291     *
292     * <p>Where possible this method zero-pads the probabilities so the length is the next
293     * power-of-two. Padding is bounded by the upper limit on the size of an array.</p>
294     *
295     * <p>To avoid zero-padding use the
296     * {@link #of(UniformRandomProvider, double[], int)} method with a negative
297     * {@code alpha} factor.</p>
298     *
299     * @param rng Generator of uniformly distributed random numbers.
300     * @param probabilities The list of probabilities.
301     * @return the sampler
302     * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
303     * probability is negative, infinite or {@code NaN}, or the sum of all
304     * probabilities is not strictly positive.
305     * @see #of(UniformRandomProvider, double[], int)
306     */
307    public static SharedStateDiscreteSampler of(final UniformRandomProvider rng,
308                                                final double[] probabilities) {
309        return of(rng, probabilities, DEFAULT_ALPHA);
310    }
311
312    /**
313     * Creates a sampler.
314     *
315     * <p>The probabilities will be normalised using their sum. The only requirement
316     * is the sum is strictly positive.</p>
317     *
318     * <p>Where possible this method zero-pads the probabilities to improve sampling
319     * efficiency. Padding is bounded by the upper limit on the size of an array and
320     * controlled by the {@code alpha} argument. Set to negative to disable
321     * padding.</p>
322     *
323     * <p>For each zero padded value an entry is added to the tables which is always
324     * aliased. This can be sampled with fewer bits required from the
325     * {@link UniformRandomProvider}. Increasing the padding of zeros increases the
326     * chance of using this fast path to selecting a sample. The penalty is
327     * two-fold: initialisation is bounded by {@code O(n)} time with {@code n} the
328     * size <strong>after</strong> padding; an additional memory cost of 4 bytes per
329     * padded value.</p>
330     *
331     * <p>Zero padding to any length improves performance; using a power of 2 allows
332     * the index into the tables to be more efficiently generated. The argument
333     * {@code alpha} controls the level of padding. Positive values of {@code alpha}
334     * represent a scale factor in powers of 2. The size of the input array will be
335     * increased by a factor of 2<sup>alpha</sup> and then rounded-up to the next
336     * power of 2. Padding is bounded by the upper limit on the size of an
337     * array.</p>
338     *
339     * <p>The chance of executing the slow path is upper bounded at
340     * 2<sup>-alpha</sup> when padding is enabled. Each successive doubling of
341     * padding will have diminishing performance gains.</p>
342     *
343     * @param rng Generator of uniformly distributed random numbers.
344     * @param probabilities The list of probabilities.
345     * @param alpha The alpha factor controlling the zero padding.
346     * @return the sampler
347     * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
348     * probability is negative, infinite or {@code NaN}, or the sum of all
349     * probabilities is not strictly positive.
350     */
351    public static SharedStateDiscreteSampler of(final UniformRandomProvider rng,
352                                                final double[] probabilities,
353                                                int alpha) {
354        // The Alias method balances N categories with counts around the mean into N sections,
355        // each allocated 'mean' observations.
356        //
357        // Consider 4 categories with counts 6,3,2,1. The histogram can be balanced into a
358        // 2D array as 4 sections with a height of the mean:
359        //
360        // 6
361        // 6
362        // 6
363        // 63   => 6366   --
364        // 632     6326    |-- mean
365        // 6321    6321   --
366        //
367        // section abcd
368        //
369        // Each section is divided as:
370        // a: 6=1/1
371        // b: 3=1/1
372        // c: 2=2/3; 6=1/3   (6 is the alias)
373        // d: 1=1/3; 6=2/3   (6 is the alias)
374        //
375        // The sample is obtained by randomly selecting a section, then choosing which category
376        // from the pair based on a uniform random deviate.
377
378        final double sumProb = InternalUtils.validateProbabilities(probabilities);
379
380        // Allow zero-padding
381        final int n = computeSize(probabilities.length, alpha);
382
383        // Partition into small and large by splitting on the average.
384        final double mean = sumProb / n;
385        // The cardinality of smallSize + largeSize = n.
386        // So fill the same array from either end.
387        final int[] indices = new int[n];
388        int large = n;
389        int small = 0;
390        for (int i = 0; i < probabilities.length; i++) {
391            if (probabilities[i] >= mean) {
392                indices[--large] = i;
393            } else {
394                indices[small++] = i;
395            }
396        }
397
398        small = fillRemainingIndices(probabilities.length, indices, small);
399
400        // This may be smaller than the input length if the probabilities were already padded.
401        final int nonZeroIndex = findLastNonZeroIndex(probabilities);
402
403        // The probabilities are modified so use a copy.
404        // Note: probabilities are required only up to last nonZeroIndex
405        final double[] remainingProbabilities = Arrays.copyOf(probabilities, nonZeroIndex + 1);
406
407        // Allocate the final tables.
408        // Probability table may be truncated (when zero padded).
409        // The alias table is full length.
410        final long[] probability = new long[remainingProbabilities.length];
411        final int[] alias = new int[n];
412
413        // This loop uses each large in turn to fill the alias table for small probabilities that
414        // do not reach the requirement to fill an entire section alone (i.e. p < mean).
415        // Since the sum of the small should be less than the sum of the large it should use up
416        // all the small first. However floating point round-off can result in
417        // misclassification of items as small or large. The Vose algorithm handles this using
418        // a while loop conditioned on the size of both sets and a subsequent loop to use
419        // unpaired items.
420        while (large != n && small != 0) {
421            // Index of the small and the large probabilities.
422            final int j = indices[--small];
423            final int k = indices[large++];
424
425            // Optimisation for zero-padded input:
426            // p(j) = 0 above the last nonZeroIndex
427            if (j > nonZeroIndex) {
428                // The entire amount for the section is taken from the alias.
429                remainingProbabilities[k] -= mean;
430            } else {
431                final double pj = remainingProbabilities[j];
432
433                // Item j is a small probability that is below the mean.
434                // Compute the weight of the section for item j: pj / mean.
435                // This is scaled by 2^53 and the ceiling function used to round-up
436                // the probability to a numerator of a fraction in the range [1,2^53].
437                // Ceiling ensures non-zero values.
438                probability[j] = (long) Math.ceil(CONVERT_TO_NUMERATOR * (pj / mean));
439
440                // The remaining amount for the section is taken from the alias.
441                // Effectively: probabilities[k] -= (mean - pj)
442                remainingProbabilities[k] += pj - mean;
443            }
444
445            // If not j then the alias is k
446            alias[j] = k;
447
448            // Add the remaining probability from large to the appropriate list.
449            if (remainingProbabilities[k] >= mean) {
450                indices[--large] = k;
451            } else {
452                indices[small++] = k;
453            }
454        }
455
456        // Final loop conditions to consume unpaired items.
457        // Note: The large set should never be non-empty but this can occur due to round-off
458        // error so consume from both.
459        fillTable(probability, alias, indices, 0, small);
460        fillTable(probability, alias, indices, large, n);
461
462        // Change the algorithm for small power of 2 sized tables
463        return isSmallPowerOf2(n) ?
464            new SmallTableAliasMethodDiscreteSampler(rng, probability, alias) :
465            new AliasMethodDiscreteSampler(rng, probability, alias);
466    }
467
468    /**
469     * Allocate the remaining indices from zero padding as small probabilities. The
470     * number to add is from the length of the probability array to the length of
471     * the padded probability array (which is the same length as the indices array).
472     *
473     * @param length Length of probability array.
474     * @param indices Indices.
475     * @param small Number of small indices.
476     * @return the updated number of small indices
477     */
478    private static int fillRemainingIndices(final int length, final int[] indices, int small) {
479        int updatedSmall = small;
480        for (int i = length; i < indices.length; i++) {
481            indices[updatedSmall++] = i;
482        }
483        return updatedSmall;
484    }
485
486    /**
487     * Find the last non-zero index in the probabilities. This may be smaller than
488     * the input length if the probabilities were already padded.
489     *
490     * @param probabilities The list of probabilities.
491     * @return the index
492     */
493    private static int findLastNonZeroIndex(final double[] probabilities) {
494        // No bounds check is performed when decrementing as the array contains at least one
495        // value above zero.
496        int nonZeroIndex = probabilities.length - 1;
497        while (probabilities[nonZeroIndex] == ZERO) {
498            nonZeroIndex--;
499        }
500        return nonZeroIndex;
501    }
502
503    /**
504     * Compute the size after padding. A value of {@code alpha < 0} disables
505     * padding. Otherwise the length will be increased by 2<sup>alpha</sup>
506     * rounded-up to the next power of 2.
507     *
508     * @param length Length of probability array.
509     * @param alpha The alpha factor controlling the zero padding.
510     * @return the padded size
511     */
512    private static int computeSize(int length, int alpha) {
513        if (alpha < 0) {
514            // No padding
515            return length;
516        }
517        // Use the number of leading zeros function to find the next power of 2,
518        // i.e. ceil(log2(x))
519        int pow2 = 32 - Integer.numberOfLeadingZeros(length - 1);
520        // Increase by the alpha. Clip this to limit to a positive integer (2^30)
521        pow2 = Math.min(30, pow2 + alpha);
522        // Use max to handle a length above the highest possible power of 2
523        return Math.max(length, 1 << pow2);
524    }
525
526    /**
527     * Fill the tables using unpaired items that are in the range between {@code start} inclusive
528     * and {@code end} exclusive.
529     *
530     * <p>Anything left must fill the entire section so the probability table is set
531     * to 1 and there is no alias. This will occur for 1/n samples, i.e. the last
532     * remaining unpaired probability. Note: When the tables are zero-padded the
533     * remaining indices are from an input probability that is above zero so the
534     * index will be allowed in the truncated probability array and no
535     * index-out-of-bounds exception will occur.
536     *
537     * @param probability Probability table.
538     * @param alias Alias table.
539     * @param indices Unpaired indices.
540     * @param start Start position.
541     * @param end End position.
542     */
543    private static void fillTable(long[] probability, int[] alias, int[] indices, int start, int end) {
544        for (int i = start; i < end; i++) {
545            final int index = indices[i];
546            probability[index] = ONE_AS_NUMERATOR;
547            alias[index] = index;
548        }
549    }
550
551    /**
552     * Checks if the size is a small power of 2 so can be supported by the
553     * {@link SmallTableAliasMethodDiscreteSampler}.
554     *
555     * @param n Size of the alias table.
556     * @return true if supported by {@link SmallTableAliasMethodDiscreteSampler}
557     */
558    private static boolean isSmallPowerOf2(int n) {
559        return n <= MAX_SMALL_POWER_2_SIZE && (n & (n - 1)) == 0;
560    }
561}