1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.math.BigInteger;
20 import java.util.Arrays;
21 import org.apache.commons.rng.UniformRandomProvider;
22
23 /**
24 * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
25 * sample from {@code n} values each with an associated relative weight. If all unique items
26 * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
27 *
28 * <p>Given a list {@code L} of {@code n} positive numbers,
29 * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
30 * integer {@code i} with relative probability {@code L[i]}.
31 *
32 * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
33 * <ul>
34 * <li>For integer weights, the probability of returning {@code i} is precisely equal to the
35 * rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
36 * <li>For floating-points weights, each weight {@code L[i]} is converted to the
37 * corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
38 * {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
39 * </ul>
40 *
41 * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
42 * ignores very small relative weights may have improved sampling performance.
43 *
44 * <p>This implementation is based on the algorithm in:
45 *
46 * <blockquote>
47 * Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
48 * The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
49 * Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
50 * Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
51 * Palermo, Sicily, Italy, 2020.
52 * </blockquote>
53 *
54 * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
55 *
56 * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
57 * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
58 * PMLR 108:1036-1046.</a>
59 * @since 1.5
60 */
61 public abstract class FastLoadedDiceRollerDiscreteSampler
62 implements SharedStateDiscreteSampler {
63 /**
64 * The maximum size of an array.
65 *
66 * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
67 * It allows VMs to reserve some header words in an array.
68 */
69 private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
70 /** The maximum biased exponent for a finite double.
71 * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */
72 private static final int MAX_BIASED_EXPONENT = 2046;
73 /** Size of the mantissa of a double. Equal to 52 bits. */
74 private static final int MANTISSA_SIZE = 52;
75 /** Mask to extract the 52-bit mantissa from a long representation of a double. */
76 private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
77 /** BigInteger representation of {@link Long#MAX_VALUE}. */
78 private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
79 /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
80 * The value will remain positive for any shift {@code <=} this value. */
81 private static final int MAX_OFFSET = 10;
82 /** Initial value for no leaf node label. */
83 private static final int NO_LABEL = Integer.MAX_VALUE;
84 /** Name of the sampler. */
85 private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";
86
87 /**
88 * Class to handle the edge case of observations in only one category.
89 */
90 private static final class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
91 /** The sample value. */
92 private final int sampleValue;
93
94 /**
95 * @param sampleValue Sample value.
96 */
97 FixedValueDiscreteSampler(int sampleValue) {
98 this.sampleValue = sampleValue;
99 }
100
101 @Override
102 public int sample() {
103 return sampleValue;
104 }
105
106 @Override
107 public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
108 return this;
109 }
110
111 @Override
112 public String toString() {
113 return SAMPLER_NAME;
114 }
115 }
116
117 /**
118 * Class to implement the FLDR sample algorithm.
119 */
120 private static final class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
121 /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
122 * the boolean source. */
123 private static final int EMPTY_BOOL_SOURCE = 1;
124
125 /** Underlying source of randomness. */
126 private final UniformRandomProvider rng;
127 /** Number of categories. */
128 private final int n;
129 /** Number of levels in the discrete distribution generating (DDG) tree.
130 * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
131 private final int k;
132 /** Number of leaf nodes at each level. */
133 private final int[] h;
134 /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
135 private final int[] lH;
136
137 /**
138 * Provides a bit source for booleans.
139 *
140 * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
141 *
142 * <p>Only stores 31-bits when full as 1 bit has already been consumed.
143 * The sign bit is a flag that shifts down so the source eventually equals 1
144 * when all bits are consumed and will trigger a refill.
145 */
146 private int booleanSource = EMPTY_BOOL_SOURCE;
147
148 /**
149 * Creates a sampler.
150 *
151 * <p>The input parameters are not validated and must be correctly computed tables.
152 *
153 * @param rng Generator of uniformly distributed random numbers.
154 * @param n Number of categories
155 * @param k Number of levels in the discrete distribution generating (DDG) tree.
156 * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
157 * @param h Number of leaf nodes at each level.
158 * @param lH Stores the leaf node labels in increasing order.
159 */
160 FLDRSampler(UniformRandomProvider rng,
161 int n,
162 int k,
163 int[] h,
164 int[] lH) {
165 this.rng = rng;
166 this.n = n;
167 this.k = k;
168 // Deliberate direct storage of input arrays
169 this.h = h;
170 this.lH = lH;
171 }
172
173 /**
174 * Creates a copy with a new source of randomness.
175 *
176 * @param rng Generator of uniformly distributed random numbers.
177 * @param source Source to copy.
178 */
179 private FLDRSampler(UniformRandomProvider rng,
180 FLDRSampler source) {
181 this.rng = rng;
182 this.n = source.n;
183 this.k = source.k;
184 this.h = source.h;
185 this.lH = source.lH;
186 }
187
188 /** {@inheritDoc} */
189 @Override
190 public int sample() {
191 // ALGORITHM 5: SAMPLE
192 int c = 0;
193 int d = 0;
194 for (;;) {
195 // b = flip()
196 // d = 2 * d + (1 - b)
197 d = (d << 1) + flip();
198 if (d < h[c]) {
199 // z = H[d][c]
200 final int z = lH[d * k + c];
201 // assert z != NO_LABEL
202 if (z < n) {
203 return z;
204 }
205 d = 0;
206 c = 0;
207 } else {
208 d = d - h[c];
209 c++;
210 }
211 }
212 }
213
214 /**
215 * Provides a source of boolean bits.
216 *
217 * <p>Note: This replicates the boolean cache functionality of
218 * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
219 * an {@code int} value rather than a {@code boolean}.
220 *
221 * @return the bit (0 or 1)
222 */
223 private int flip() {
224 int bits = booleanSource;
225 if (bits == 1) {
226 // Refill
227 bits = rng.nextInt();
228 // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
229 booleanSource = Integer.MIN_VALUE | (bits >>> 1);
230 return bits & 0x1;
231 }
232 // Shift down eventually triggering refill, return current lowest bit
233 booleanSource = bits >>> 1;
234 return bits & 0x1;
235 }
236
237 /** {@inheritDoc} */
238 @Override
239 public String toString() {
240 return SAMPLER_NAME + " [" + rng.toString() + "]";
241 }
242
243 /** {@inheritDoc} */
244 @Override
245 public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
246 return new FLDRSampler(rng, this);
247 }
248 }
249
250 /** Package-private constructor. */
251 FastLoadedDiceRollerDiscreteSampler() {
252 // Intentionally empty
253 }
254
255 /** {@inheritDoc} */
256 // Redeclare the signature to return a FastLoadedDiceRollerSampler not a SharedStateLongSampler
257 @Override
258 public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);
259
260 /**
261 * Creates a sampler.
262 *
263 * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
264 * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
265 * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
266 * as a single array.
267 *
268 * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
269 * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
270 * when the sum of frequencies is large enough to create k=63.
271 *
272 * @param rng Generator of uniformly distributed random numbers.
273 * @param frequencies Observed frequencies of the discrete distribution.
274 * @return the sampler
275 * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
276 * frequency is negative, the sum of all frequencies is either zero or
277 * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
278 * is too large.
279 */
280 public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
281 long[] frequencies) {
282 final long m = sum(frequencies);
283
284 // Obtain indices of non-zero frequencies
285 final int[] indices = indicesOfNonZero(frequencies);
286
287 // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
288 // (as log2(m) == 0 will break the computation of the DDG tree).
289 if (indices.length == 1) {
290 return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
291 }
292
293 return createSampler(rng, frequencies, indices, m);
294 }
295
296 /**
297 * Creates a sampler.
298 *
299 * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
300 * The numerators {@code p} are scaled to use a common denominator before summing.
301 *
302 * <p>All weights are used to create the sampler. Weights with a small magnitude relative
303 * to the largest weight can be excluded using the constructor method with the
304 * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
305 *
306 * @param rng Generator of uniformly distributed random numbers.
307 * @param weights Weights of the discrete distribution.
308 * @return the sampler
309 * @throws IllegalArgumentException if {@code weights} is null or empty, a
310 * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
311 * of the discrete distribution generating tree is too large.
312 * @see #of(UniformRandomProvider, double[], int)
313 */
314 public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
315 double[] weights) {
316 return of(rng, weights, 0);
317 }
318
319 /**
320 * Creates a sampler.
321 *
322 * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
323 * a power of 2. The numerators {@code p} are scaled to use a common
324 * denominator before summing.
325 *
326 * <p>Note: The discrete distribution generating (DDG) tree requires
327 * {@code (n + 1) * k} entries where {@code n} is the number of categories,
328 * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators
329 * {@code q}. An exception is raised if this cannot be allocated as a single
330 * array.
331 *
332 * <p>For reference the value {@code k} is equal to or greater than the ratio of
333 * the largest to the smallest weight expressed as a power of 2. For
334 * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
335 * {@code k} increases with the sum of the weight numerators. A number of
336 * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
337 * would be required to raise an exception when the minimum weight is
338 * {@link Double#MIN_VALUE}.
339 *
340 * <p>Weights with a small magnitude relative to the largest weight can be
341 * excluded using the relative magnitude parameter {@code alpha}. This will set
342 * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
343 * <em>smaller</em> than the largest weight. This comparison is made using only
344 * the exponent of the input weights. The {@code alpha} parameter is ignored if
345 * not above zero. Note that a small {@code alpha} parameter will exclude more
346 * weights than a large {@code alpha} parameter.
347 *
348 * <p>The alpha parameter can be used to exclude categories that
349 * have a very low probability of occurrence and will improve the construction
350 * performance of the sampler. The effect on sampling performance depends on
351 * the relative weights of the excluded categories; typically a high {@code alpha}
352 * is used to exclude categories that would be visited with a very low probability
353 * and the sampling performance is unchanged.
354 *
355 * <p><b>Implementation Note</b>
356 *
357 * <p>This method creates a sampler with <em>exact</em> samples from the
358 * specified probability distribution. It is recommended to use this method:
359 * <ul>
360 * <li>if the weights are computed, for example from a probability mass function; or
361 * <li>if the weights sum to an infinite value.
362 * </ul>
363 *
364 * <p>If the weights are computed from empirical observations then it is
365 * recommended to use the factory method
366 * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
367 * requires the total number of observations to be representable as a long
368 * integer.
369 *
370 * <p>Note that if all weights are scaled by a power of 2 to be integers, and
371 * each integer can be represented as a positive 64-bit long value, then the
372 * sampler created using this method will match the output from a sampler
373 * created with the scaled weights converted to long values for the factory
374 * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
375 * assumes the sum of the integer values does not overflow.
376 *
377 * <p>It should be noted that the conversion of weights to rational numbers has
378 * a performance overhead during construction (sampling performance is not
379 * affected). This may be avoided by first converting them to integer values
380 * that can be summed without overflow. For example by scaling values by
381 * {@code 2^62 / sum} and converting to long by casting or rounding.
382 *
383 * <p>This approach may increase the efficiency of construction. The resulting
384 * sampler may no longer produce <em>exact</em> samples from the distribution.
385 * In particular any weights with a converted frequency of zero cannot be
386 * sampled.
387 *
388 * @param rng Generator of uniformly distributed random numbers.
389 * @param weights Weights of the discrete distribution.
390 * @param alpha Alpha parameter.
391 * @return the sampler
392 * @throws IllegalArgumentException if {@code weights} is null or empty, a
393 * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
394 * or the size of the discrete distribution generating tree is too large.
395 * @see #of(UniformRandomProvider, long[])
396 */
397 public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
398 double[] weights,
399 int alpha) {
400 final int n = checkWeightsNonZeroLength(weights);
401
402 // Convert floating-point double to a relative weight
403 // using a shifted integer representation
404 final long[] frequencies = new long[n];
405 final int[] offsets = new int[n];
406 convertToIntegers(weights, frequencies, offsets, alpha);
407
408 // Obtain indices of non-zero weights
409 final int[] indices = indicesOfNonZero(frequencies);
410
411 // Edge case for 1 non-zero weight.
412 if (indices.length == 1) {
413 return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
414 }
415
416 final BigInteger m = sum(frequencies, offsets, indices);
417
418 // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
419 if (m.compareTo(MAX_LONG) <= 0) {
420 // Apply the offset
421 for (int i = 0; i < n; i++) {
422 frequencies[i] <<= offsets[i];
423 }
424 return createSampler(rng, frequencies, indices, m.longValue());
425 }
426
427 return createSampler(rng, frequencies, offsets, indices, m);
428 }
429
430 /**
431 * Sum the frequencies.
432 *
433 * @param frequencies Frequencies.
434 * @return the sum
435 * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
436 * frequency is negative, or the sum of all frequencies is either zero or above
437 * {@link Long#MAX_VALUE}
438 */
439 private static long sum(long[] frequencies) {
440 // Validate
441 if (frequencies == null || frequencies.length == 0) {
442 throw new IllegalArgumentException("frequencies must contain at least 1 value");
443 }
444
445 // Sum the values.
446 // Combine all the sign bits in the observations and the intermediate sum in a flag.
447 long m = 0;
448 long signFlag = 0;
449 for (final long o : frequencies) {
450 m += o;
451 signFlag |= o | m;
452 }
453
454 // Check for a sign-bit.
455 if (signFlag < 0) {
456 // One or more observations were negative, or the sum overflowed.
457 for (final long o : frequencies) {
458 if (o < 0) {
459 throw new IllegalArgumentException("frequencies must contain positive values: " + o);
460 }
461 }
462 throw new IllegalArgumentException("Overflow when summing frequencies");
463 }
464 if (m == 0) {
465 throw new IllegalArgumentException("Sum of frequencies is zero");
466 }
467 return m;
468 }
469
470 /**
471 * Convert the floating-point weights to relative weights represented as
472 * integers {@code value * 2^exponent}. The relative weight as an integer is:
473 *
474 * <pre>
475 * BigInteger.valueOf(value).shiftLeft(exponent)
476 * </pre>
477 *
478 * <p>Note that the weights are created using a common power-of-2 scaling
479 * operation so the minimum exponent is zero.
480 *
481 * <p>A positive {@code alpha} parameter is used to set any weight to zero if
482 * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
483 * largest weight. This comparison is made using only the exponent of the input
484 * weights.
485 *
486 * @param weights Weights of the discrete distribution.
487 * @param values Output floating-point mantissas converted to 53-bit integers.
488 * @param exponents Output power of 2 exponent.
489 * @param alpha Alpha parameter.
490 * @throws IllegalArgumentException if a weight is negative, infinite or
491 * {@code NaN}, or the sum of all weights is zero.
492 */
493 private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
494 int maxExponent = Integer.MIN_VALUE;
495 for (int i = 0; i < weights.length; i++) {
496 final double weight = weights[i];
497 // Ignore zero.
498 // When creating the integer value later using bit shifts the result will remain zero.
499 if (weight == 0) {
500 continue;
501 }
502 final long bits = Double.doubleToRawLongBits(weight);
503
504 // For the IEEE 754 format see Double.longBitsToDouble(long).
505
506 // Extract the exponent (with the sign bit)
507 int exp = (int) (bits >>> MANTISSA_SIZE);
508 // Detect negative, infinite or NaN.
509 // Note: Negative values sign bit will cause the exponent to be too high.
510 if (exp > MAX_BIASED_EXPONENT) {
511 throw new IllegalArgumentException("Invalid weight: " + weight);
512 }
513 long mantissa;
514 if (exp == 0) {
515 // Sub-normal number:
516 mantissa = (bits & MANTISSA_MASK) << 1;
517 // Here we convert to a normalised number by counting the leading zeros
518 // to obtain the number of shifts of the most significant bit in
519 // the mantissa that is required to get a 1 at position 53 (i.e. as
520 // if it were a normal number with assumed leading bit).
521 final int shift = Long.numberOfLeadingZeros(mantissa << 11);
522 mantissa <<= shift;
523 exp -= shift;
524 } else {
525 // Normal number. Add the implicit leading 1-bit.
526 mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
527 }
528
529 // Here the floating-point number is equal to:
530 // mantissa * 2^(exp-1075)
531
532 values[i] = mantissa;
533 exponents[i] = exp;
534 maxExponent = Math.max(maxExponent, exp);
535 }
536
537 // No exponent indicates that all weights are zero
538 if (maxExponent == Integer.MIN_VALUE) {
539 throw new IllegalArgumentException("Sum of weights is zero");
540 }
541
542 filterWeights(values, exponents, alpha, maxExponent);
543 scaleWeights(values, exponents);
544 }
545
546 /**
547 * Filters small weights using the {@code alpha} parameter.
548 * A positive {@code alpha} parameter is used to set any weight to zero if
549 * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
550 * largest weight. This comparison is made using only the exponent of the input
551 * weights.
552 *
553 * @param values 53-bit values.
554 * @param exponents Power of 2 exponent.
555 * @param alpha Alpha parameter.
556 * @param maxExponent Maximum exponent.
557 */
558 private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
559 if (alpha > 0) {
560 // Filter weights. This must be done before the values are shifted so
561 // the exponent represents the approximate magnitude of the value.
562 for (int i = 0; i < exponents.length; i++) {
563 if (maxExponent - exponents[i] > alpha) {
564 values[i] = 0;
565 }
566 }
567 }
568 }
569
570 /**
571 * Scale the weights represented as integers {@code value * 2^exponent} to use a
572 * minimum exponent of zero. The values are scaled to remove any common trailing zeros
573 * in their representation. This ultimately reduces the size of the discrete distribution
574 * generating (DGG) tree.
575 *
576 * @param values 53-bit values.
577 * @param exponents Power of 2 exponent.
578 */
579 private static void scaleWeights(long[] values, int[] exponents) {
580 // Find the minimum exponent and common trailing zeros.
581 int minExponent = Integer.MAX_VALUE;
582 for (int i = 0; i < exponents.length; i++) {
583 if (values[i] != 0) {
584 minExponent = Math.min(minExponent, exponents[i]);
585 }
586 }
587 // Trailing zeros occur when the original weights have a representation with
588 // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
589 int trailingZeros = Long.SIZE;
590 for (int i = 0; i < values.length && trailingZeros != 0; i++) {
591 trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
592 }
593 // Scale by a power of 2 so the minimum exponent is zero.
594 for (int i = 0; i < exponents.length; i++) {
595 exponents[i] -= minExponent;
596 }
597 // Remove common trailing zeros.
598 if (trailingZeros != 0) {
599 for (int i = 0; i < values.length; i++) {
600 values[i] >>>= trailingZeros;
601 }
602 }
603 }
604
605 /**
606 * Sum the integers at the specified indices.
607 * Integers are represented as {@code value * 2^exponent}.
608 *
609 * @param values 53-bit values.
610 * @param exponents Power of 2 exponent.
611 * @param indices Indices to sum.
612 * @return the sum
613 */
614 private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
615 BigInteger m = BigInteger.ZERO;
616 for (final int i : indices) {
617 m = m.add(toBigInteger(values[i], exponents[i]));
618 }
619 return m;
620 }
621
622 /**
623 * Convert the value and left shift offset to a BigInteger.
624 * It is assumed the value is at most 53-bits. This allows optimising the left
625 * shift if it is below 11 bits.
626 *
627 * @param value 53-bit value.
628 * @param offset Left shift offset (must be positive).
629 * @return the BigInteger
630 */
631 private static BigInteger toBigInteger(long value, int offset) {
632 // Ignore zeros. The sum method uses indices of non-zero values.
633 if (offset <= MAX_OFFSET) {
634 // Assume (value << offset) <= Long.MAX_VALUE
635 return BigInteger.valueOf(value << offset);
636 }
637 return BigInteger.valueOf(value).shiftLeft(offset);
638 }
639
640 /**
641 * Creates the sampler.
642 *
643 * <p>It is assumed the frequencies are all positive and the sum does not
644 * overflow.
645 *
646 * @param rng Generator of uniformly distributed random numbers.
647 * @param frequencies Observed frequencies of the discrete distribution.
648 * @param indices Indices of non-zero frequencies.
649 * @param m Sum of the frequencies.
650 * @return the sampler
651 */
652 private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
653 long[] frequencies,
654 int[] indices,
655 long m) {
656 // ALGORITHM 5: PREPROCESS
657 // a == frequencies
658 // m = sum(a)
659 // h = leaf node count
660 // H = leaf node label (lH)
661
662 final int n = frequencies.length;
663
664 // k = ceil(log2(m))
665 final int k = 64 - Long.numberOfLeadingZeros(m - 1);
666 // r = a(n+1) = 2^k - m
667 final long r = (1L << k) - m;
668
669 // Note:
670 // A sparse matrix can often be used for H, as most of its entries are empty.
671 // This implementation uses a 1D array for efficiency at the cost of memory.
672 // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
673 // observations is large enough to create k=63.
674 // This could be handled using a 2D array. In practice a number of categories this
675 // large is not expected and is currently not supported.
676 final int[] h = new int[k];
677 final int[] lH = new int[checkArraySize((n + 1L) * k)];
678 Arrays.fill(lH, NO_LABEL);
679
680 int d;
681 for (int j = 0; j < k; j++) {
682 final int shift = (k - 1) - j;
683 final long bitMask = 1L << shift;
684
685 d = 0;
686 for (final int i : indices) {
687 // bool w ← (a[i] >> (k − 1) − j)) & 1
688 // h[j] = h[j] + w
689 // if w then:
690 if ((frequencies[i] & bitMask) != 0) {
691 h[j]++;
692 // H[d][j] = i
693 lH[d * k + j] = i;
694 d++;
695 }
696 }
697 // process a(n+1) without extending the input frequencies array by 1
698 if ((r & bitMask) != 0) {
699 h[j]++;
700 lH[d * k + j] = n;
701 }
702 }
703
704 return new FLDRSampler(rng, n, k, h, lH);
705 }
706
707 /**
708 * Creates the sampler. Frequencies are represented as a 53-bit value with a
709 * left-shift offset.
710 * <pre>
711 * BigInteger.valueOf(value).shiftLeft(offset)
712 * </pre>
713 *
714 * <p>It is assumed the frequencies are all positive.
715 *
716 * @param rng Generator of uniformly distributed random numbers.
717 * @param frequencies Observed frequencies of the discrete distribution.
718 * @param offsets Left shift offsets (must be positive).
719 * @param indices Indices of non-zero frequencies.
720 * @param m Sum of the frequencies.
721 * @return the sampler
722 */
723 private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
724 long[] frequencies,
725 int[] offsets,
726 int[] indices,
727 BigInteger m) {
728 // Repeat the logic from createSampler(...) using extended arithmetic to test the bits
729
730 // ALGORITHM 5: PREPROCESS
731 // a == frequencies
732 // m = sum(a)
733 // h = leaf node count
734 // H = leaf node label (lH)
735
736 final int n = frequencies.length;
737
738 // k = ceil(log2(m))
739 final int k = m.subtract(BigInteger.ONE).bitLength();
740 // r = a(n+1) = 2^k - m
741 final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);
742
743 final int[] h = new int[k];
744 final int[] lH = new int[checkArraySize((n + 1L) * k)];
745 Arrays.fill(lH, NO_LABEL);
746
747 int d;
748 for (int j = 0; j < k; j++) {
749 final int shift = (k - 1) - j;
750
751 d = 0;
752 for (final int i : indices) {
753 // bool w ← (a[i] >> (k − 1) − j)) & 1
754 // h[j] = h[j] + w
755 // if w then:
756 if (testBit(frequencies[i], offsets[i], shift)) {
757 h[j]++;
758 // H[d][j] = i
759 lH[d * k + j] = i;
760 d++;
761 }
762 }
763 // process a(n+1) without extending the input frequencies array by 1
764 if (r.testBit(shift)) {
765 h[j]++;
766 lH[d * k + j] = n;
767 }
768 }
769
770 return new FLDRSampler(rng, n, k, h, lH);
771 }
772
773 /**
774 * Test the logical bit of the shifted integer representation.
775 * The value is assumed to have at most 53-bits of information. The offset
776 * is assumed to be positive. This is functionally equivalent to:
777 * <pre>
778 * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
779 * </pre>
780 *
781 * @param value 53-bit value.
782 * @param offset Left shift offset.
783 * @param n Index of bit to test.
784 * @return true if the bit is 1
785 */
786 private static boolean testBit(long value, int offset, int n) {
787 if (n < offset) {
788 // All logical trailing bits are zero
789 return false;
790 }
791 // Test if outside the 53-bit value (note that the implicit 1 bit
792 // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
793 final int bit = n - offset;
794 return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
795 }
796
797 /**
798 * Check the weights have a non-zero length.
799 *
800 * @param weights Weights.
801 * @return the length
802 */
803 private static int checkWeightsNonZeroLength(double[] weights) {
804 if (weights == null || weights.length == 0) {
805 throw new IllegalArgumentException("weights must contain at least 1 value");
806 }
807 return weights.length;
808 }
809
810 /**
811 * Create the indices of non-zero values.
812 *
813 * @param values Values.
814 * @return the indices
815 */
816 private static int[] indicesOfNonZero(long[] values) {
817 int n = 0;
818 final int[] indices = new int[values.length];
819 for (int i = 0; i < values.length; i++) {
820 if (values[i] != 0) {
821 indices[n++] = i;
822 }
823 }
824 return Arrays.copyOf(indices, n);
825 }
826
827 /**
828 * Find the index of the first non-zero frequency.
829 *
830 * @param frequencies Frequencies.
831 * @return the index
832 * @throws IllegalStateException if all frequencies are zero.
833 */
834 static int indexOfNonZero(long[] frequencies) {
835 for (int i = 0; i < frequencies.length; i++) {
836 if (frequencies[i] != 0) {
837 return i;
838 }
839 }
840 throw new IllegalStateException("All frequencies are zero");
841 }
842
843 /**
844 * Check the size is valid for a 1D array.
845 *
846 * @param size Size
847 * @return the size as an {@code int}
848 * @throws IllegalArgumentException if the size is too large for a 1D array.
849 */
850 static int checkArraySize(long size) {
851 if (size > MAX_ARRAY_SIZE) {
852 throw new IllegalArgumentException("Unable to allocate array of size: " + size);
853 }
854 return (int) size;
855 }
856 }