/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling.distribution;

import java.math.BigInteger;
import java.util.Arrays;
import org.apache.commons.rng.UniformRandomProvider;

/**
 * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to
 * sample from {@code n} values each with an associated relative weight. If all unique items
 * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}.
 *
 * <p>Given a list {@code L} of {@code n} positive numbers,
 * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns
 * integer {@code i} with relative probability {@code L[i]}.
 *
 * <p>FLDR produces <em>exact</em> samples from the specified probability distribution.
 * <ul>
 *   <li>For integer weights, the probability of returning {@code i} is precisely equal to the
 *   rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}.
 *   <li>For floating-point weights, each weight {@code L[i]} is converted to the
 *   corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and
 *   {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity.
 * </ul>
 *
 * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that
 * ignores very small relative weights may have improved sampling performance.
 *
 * <p>This implementation is based on the algorithm in:
 *
 * <blockquote>
 *  Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka.
 *  The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability
 *  Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on
 *  Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108,
 *  Palermo, Sicily, Italy, 2020.
 * </blockquote>
 *
 * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits.
 *
 * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020)
 * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics,
 * PMLR 108:1036-1046.</a>
 * @since 1.5
 */
public abstract class FastLoadedDiceRollerDiscreteSampler
    implements SharedStateDiscreteSampler {
    /**
     * The maximum size of an array.
     *
     * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}.
     * It allows VMs to reserve some header words in an array.
     */
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
    /** The maximum biased exponent for a finite double.
     * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */
    private static final int MAX_BIASED_EXPONENT = 2046;
    /** Size of the mantissa of a double. Equal to 52 bits. */
    private static final int MANTISSA_SIZE = 52;
    /** Mask to extract the 52-bit mantissa from a long representation of a double. */
    private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL;
    /** BigInteger representation of {@link Long#MAX_VALUE}. */
    private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE);
    /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value.
     * The value will remain positive for any shift {@code <=} this value. */
    private static final int MAX_OFFSET = 10;
    /** Initial value for no leaf node label. */
    private static final int NO_LABEL = Integer.MAX_VALUE;
    /** Name of the sampler. */
    private static final String SAMPLER_NAME = "Fast Loaded Dice Roller";

    /**
     * Class to handle the edge case of observations in only one category.
     */
    private static final class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler {
        /** The sample value. */
        private final int sampleValue;

        /**
         * @param sampleValue Sample value.
         */
        FixedValueDiscreteSampler(int sampleValue) {
            this.sampleValue = sampleValue;
        }

        /** {@inheritDoc} */
        @Override
        public int sample() {
            return sampleValue;
        }

        /** {@inheritDoc} */
        @Override
        public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
            // Stateless: no source of randomness is used so the instance can be shared.
            return this;
        }

        /** {@inheritDoc} */
        @Override
        public String toString() {
            return SAMPLER_NAME;
        }
    }

    /**
     * Class to implement the FLDR sample algorithm.
     */
    private static final class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler {
        /** Empty boolean source. This is the location of the sign-bit after 31 right shifts on
         * the boolean source. */
        private static final int EMPTY_BOOL_SOURCE = 1;

        /** Underlying source of randomness. */
        private final UniformRandomProvider rng;
        /** Number of categories. */
        private final int n;
        /** Number of levels in the discrete distribution generating (DDG) tree.
         * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */
        private final int k;
        /** Number of leaf nodes at each level. */
        private final int[] h;
        /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */
        private final int[] lH;

        /**
         * Provides a bit source for booleans.
         *
         * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}.
         *
         * <p>Only stores 31-bits when full as 1 bit has already been consumed.
         * The sign bit is a flag that shifts down so the source eventually equals 1
         * when all bits are consumed and will trigger a refill.
         */
        private int booleanSource = EMPTY_BOOL_SOURCE;

        /**
         * Creates a sampler.
         *
         * <p>The input parameters are not validated and must be correctly computed tables.
         *
         * @param rng Generator of uniformly distributed random numbers.
         * @param n Number of categories
         * @param k Number of levels in the discrete distribution generating (DDG) tree.
         * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations.
         * @param h Number of leaf nodes at each level.
         * @param lH Stores the leaf node labels in increasing order.
         */
        FLDRSampler(UniformRandomProvider rng,
                    int n,
                    int k,
                    int[] h,
                    int[] lH) {
            this.rng = rng;
            this.n = n;
            this.k = k;
            // Deliberate direct storage of input arrays
            this.h = h;
            this.lH = lH;
        }

        /**
         * Creates a copy with a new source of randomness.
         *
         * @param rng Generator of uniformly distributed random numbers.
         * @param source Source to copy.
         */
        private FLDRSampler(UniformRandomProvider rng,
                            FLDRSampler source) {
            this.rng = rng;
            this.n = source.n;
            this.k = source.k;
            // The tables are never modified after construction so they are shared.
            this.h = source.h;
            this.lH = source.lH;
        }

        /** {@inheritDoc} */
        @Override
        public int sample() {
            // ALGORITHM 5: SAMPLE
            int c = 0;
            int d = 0;
            for (;;) {
                // b = flip()
                // d = 2 * d + (1 - b)
                // Note: the random bit is used directly in place of (1 - b). Both are
                // uniform bits so the sampled distribution is unchanged.
                d = (d << 1) + flip();
                if (d < h[c]) {
                    // z = H[d][c]
                    final int z = lH[d * k + c];
                    // assert z != NO_LABEL
                    if (z < n) {
                        return z;
                    }
                    // z == n is the rejection leaf (label for 2^k - m): restart from the root.
                    d = 0;
                    c = 0;
                } else {
                    d = d - h[c];
                    c++;
                }
            }
        }

        /**
         * Provides a source of boolean bits.
         *
         * <p>Note: This replicates the boolean cache functionality of
         * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return
         * an {@code int} value rather than a {@code boolean}.
         *
         * @return the bit (0 or 1)
         */
        private int flip() {
            int bits = booleanSource;
            if (bits == 1) {
                // Refill
                bits = rng.nextInt();
                // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit
                booleanSource = Integer.MIN_VALUE | (bits >>> 1);
                return bits & 0x1;
            }
            // Shift down eventually triggering refill, return current lowest bit
            booleanSource = bits >>> 1;
            return bits & 0x1;
        }

        /** {@inheritDoc} */
        @Override
        public String toString() {
            return SAMPLER_NAME + " [" + rng.toString() + "]";
        }

        /** {@inheritDoc} */
        @Override
        public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
            return new FLDRSampler(rng, this);
        }
    }

    /** Package-private constructor. */
    FastLoadedDiceRollerDiscreteSampler() {
        // Intentionally empty
    }

    /** {@inheritDoc} */
    // Redeclare the signature to return a FastLoadedDiceRollerDiscreteSampler
    // not a SharedStateDiscreteSampler
    @Override
    public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng);

    /**
     * Creates a sampler.
     *
     * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries
     * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m}
     * is the sum of the observed frequencies. An exception is raised if this cannot be allocated
     * as a single array.
     *
     * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63.
     * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042}
     * when the sum of frequencies is large enough to create k=63.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param frequencies Observed frequencies of the discrete distribution.
     * @return the sampler
     * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
     * frequency is negative, the sum of all frequencies is either zero or
     * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree
     * is too large.
     */
    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                         long[] frequencies) {
        final long m = sum(frequencies);

        // Obtain indices of non-zero frequencies
        final int[] indices = indicesOfNonZero(frequencies);

        // Edge case for 1 non-zero weight. This also handles edge case for 1 observation
        // (as log2(m) == 0 will break the computation of the DDG tree).
        if (indices.length == 1) {
            return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
        }

        return createSampler(rng, frequencies, indices, m);
    }

    /**
     * Creates a sampler.
     *
     * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2.
     * The numerators {@code p} are scaled to use a common denominator before summing.
     *
     * <p>All weights are used to create the sampler. Weights with a small magnitude relative
     * to the largest weight can be excluded using the factory method with the
     * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}).
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param weights Weights of the discrete distribution.
     * @return the sampler
     * @throws IllegalArgumentException if {@code weights} is null or empty, a
     * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size
     * of the discrete distribution generating tree is too large.
     * @see #of(UniformRandomProvider, double[], int)
     */
    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                         double[] weights) {
        return of(rng, weights, 0);
    }

    /**
     * Creates a sampler.
     *
     * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is
     * a power of 2. The numerators {@code p} are scaled to use a common
     * denominator before summing.
     *
     * <p>Note: The discrete distribution generating (DDG) tree requires
     * {@code (n + 1) * k} entries where {@code n} is the number of categories,
     * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators
     * {@code p}. An exception is raised if this cannot be allocated as a single
     * array.
     *
     * <p>For reference the value {@code k} is equal to or greater than the ratio of
     * the largest to the smallest weight expressed as a power of 2. For
     * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value
     * {@code k} increases with the sum of the weight numerators. A number of
     * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE}
     * would be required to raise an exception when the minimum weight is
     * {@link Double#MIN_VALUE}.
     *
     * <p>Weights with a small magnitude relative to the largest weight can be
     * excluded using the relative magnitude parameter {@code alpha}. This will set
     * any weight to zero if the magnitude is approximately 2<sup>alpha</sup>
     * <em>smaller</em> than the largest weight. This comparison is made using only
     * the exponent of the input weights. The {@code alpha} parameter is ignored if
     * not above zero. Note that a small {@code alpha} parameter will exclude more
     * weights than a large {@code alpha} parameter.
     *
     * <p>The alpha parameter can be used to exclude categories that
     * have a very low probability of occurrence and will improve the construction
     * performance of the sampler. The effect on sampling performance depends on
     * the relative weights of the excluded categories; typically a high {@code alpha}
     * is used to exclude categories that would be visited with a very low probability
     * and the sampling performance is unchanged.
     *
     * <p><b>Implementation Note</b>
     *
     * <p>This method creates a sampler with <em>exact</em> samples from the
     * specified probability distribution. It is recommended to use this method:
     * <ul>
     *  <li>if the weights are computed, for example from a probability mass function; or
     *  <li>if the weights sum to an infinite value.
     * </ul>
     *
     * <p>If the weights are computed from empirical observations then it is
     * recommended to use the factory method
     * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
     * requires the total number of observations to be representable as a long
     * integer.
     *
     * <p>Note that if all weights are scaled by a power of 2 to be integers, and
     * each integer can be represented as a positive 64-bit long value, then the
     * sampler created using this method will match the output from a sampler
     * created with the scaled weights converted to long values for the factory
     * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This
     * assumes the sum of the integer values does not overflow.
     *
     * <p>It should be noted that the conversion of weights to rational numbers has
     * a performance overhead during construction (sampling performance is not
     * affected). This may be avoided by first converting them to integer values
     * that can be summed without overflow. For example by scaling values by
     * {@code 2^62 / sum} and converting to long by casting or rounding.
     *
     * <p>This approach may increase the efficiency of construction. The resulting
     * sampler may no longer produce <em>exact</em> samples from the distribution.
     * In particular any weights with a converted frequency of zero cannot be
     * sampled.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param weights Weights of the discrete distribution.
     * @param alpha Alpha parameter.
     * @return the sampler
     * @throws IllegalArgumentException if {@code weights} is null or empty, a
     * weight is negative, infinite or {@code NaN}, the sum of all weights is zero,
     * or the size of the discrete distribution generating tree is too large.
     * @see #of(UniformRandomProvider, long[])
     */
    public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng,
                                                         double[] weights,
                                                         int alpha) {
        final int n = checkWeightsNonZeroLength(weights);

        // Convert floating-point double to a relative weight
        // using a shifted integer representation
        final long[] frequencies = new long[n];
        final int[] offsets = new int[n];
        convertToIntegers(weights, frequencies, offsets, alpha);

        // Obtain indices of non-zero weights
        final int[] indices = indicesOfNonZero(frequencies);

        // Edge case for 1 non-zero weight.
        if (indices.length == 1) {
            return new FixedValueDiscreteSampler(indexOfNonZero(frequencies));
        }

        final BigInteger m = sum(frequencies, offsets, indices);

        // Use long arithmetic if possible. This occurs when the weights are similar in magnitude.
        if (m.compareTo(MAX_LONG) <= 0) {
            // Apply the offset to the non-zero values only. These are the only values
            // read when creating the sampler. The offsets of zero (or filtered) values
            // are not meaningful and may be negative after scaling to a minimum
            // exponent of zero.
            for (final int i : indices) {
                frequencies[i] <<= offsets[i];
            }
            return createSampler(rng, frequencies, indices, m.longValue());
        }

        return createSampler(rng, frequencies, offsets, indices, m);
    }

    /**
     * Sum the frequencies.
     *
     * @param frequencies Frequencies.
     * @return the sum
     * @throws IllegalArgumentException if {@code frequencies} is null or empty, a
     * frequency is negative, or the sum of all frequencies is either zero or above
     * {@link Long#MAX_VALUE}
     */
    private static long sum(long[] frequencies) {
        // Validate
        if (frequencies == null || frequencies.length == 0) {
            throw new IllegalArgumentException("frequencies must contain at least 1 value");
        }

        // Sum the values.
        // Combine all the sign bits in the observations and the intermediate sum in a flag.
        long m = 0;
        long signFlag = 0;
        for (final long o : frequencies) {
            m += o;
            signFlag |= o | m;
        }

        // Check for a sign-bit.
        if (signFlag < 0) {
            // One or more observations were negative, or the sum overflowed.
            for (final long o : frequencies) {
                if (o < 0) {
                    throw new IllegalArgumentException("frequencies must contain positive values: " + o);
                }
            }
            throw new IllegalArgumentException("Overflow when summing frequencies");
        }
        if (m == 0) {
            throw new IllegalArgumentException("Sum of frequencies is zero");
        }
        return m;
    }

    /**
     * Convert the floating-point weights to relative weights represented as
     * integers {@code value * 2^exponent}. The relative weight as an integer is:
     *
     * <pre>
     * BigInteger.valueOf(value).shiftLeft(exponent)
     * </pre>
     *
     * <p>Note that the weights are created using a common power-of-2 scaling
     * operation so the minimum exponent is zero.
     *
     * <p>A positive {@code alpha} parameter is used to set any weight to zero if
     * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
     * largest weight. This comparison is made using only the exponent of the input
     * weights.
     *
     * @param weights Weights of the discrete distribution.
     * @param values Output floating-point mantissas converted to 53-bit integers.
     * @param exponents Output power of 2 exponent.
     * @param alpha Alpha parameter.
     * @throws IllegalArgumentException if a weight is negative, infinite or
     * {@code NaN}, or the sum of all weights is zero.
     */
    private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) {
        int maxExponent = Integer.MIN_VALUE;
        for (int i = 0; i < weights.length; i++) {
            final double weight = weights[i];
            // Ignore zero.
            // When creating the integer value later using bit shifts the result will remain zero.
            if (weight == 0) {
                continue;
            }
            final long bits = Double.doubleToRawLongBits(weight);

            // For the IEEE 754 format see Double.longBitsToDouble(long).

            // Extract the exponent (with the sign bit)
            int exp = (int) (bits >>> MANTISSA_SIZE);
            // Detect negative, infinite or NaN.
            // Note: Negative values sign bit will cause the exponent to be too high.
            if (exp > MAX_BIASED_EXPONENT) {
                throw new IllegalArgumentException("Invalid weight: " + weight);
            }
            long mantissa;
            if (exp == 0) {
                // Sub-normal number:
                // The value is mantissa * 2^-1074. Shift left by 1 to express it on
                // the same scale as a normal number at exp == 0, i.e. mantissa * 2^(0-1075).
                mantissa = (bits & MANTISSA_MASK) << 1;
                // Here we convert to a normalised number by counting the leading zeros
                // to obtain the number of shifts of the most significant bit in
                // the mantissa that is required to get a 1 at bit index 52 (i.e. as
                // if it were a normal number with assumed leading bit).
                final int shift = Long.numberOfLeadingZeros(mantissa << 11);
                mantissa <<= shift;
                exp -= shift;
            } else {
                // Normal number. Add the implicit leading 1-bit.
                mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE);
            }

            // Here the floating-point number is equal to:
            // mantissa * 2^(exp-1075)

            values[i] = mantissa;
            exponents[i] = exp;
            maxExponent = Math.max(maxExponent, exp);
        }

        // No exponent indicates that all weights are zero
        if (maxExponent == Integer.MIN_VALUE) {
            throw new IllegalArgumentException("Sum of weights is zero");
        }

        filterWeights(values, exponents, alpha, maxExponent);
        scaleWeights(values, exponents);
    }

    /**
     * Filters small weights using the {@code alpha} parameter.
     * A positive {@code alpha} parameter is used to set any weight to zero if
     * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the
     * largest weight. This comparison is made using only the exponent of the input
     * weights.
     *
     * @param values 53-bit values.
     * @param exponents Power of 2 exponent.
     * @param alpha Alpha parameter.
     * @param maxExponent Maximum exponent.
     */
    private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) {
        if (alpha > 0) {
            // Filter weights. This must be done before the values are shifted so
            // the exponent represents the approximate magnitude of the value.
            for (int i = 0; i < exponents.length; i++) {
                if (maxExponent - exponents[i] > alpha) {
                    values[i] = 0;
                }
            }
        }
    }

    /**
     * Scale the weights represented as integers {@code value * 2^exponent} to use a
     * minimum exponent of zero. The values are scaled to remove any common trailing zeros
     * in their representation. This ultimately reduces the size of the discrete distribution
     * generating (DDG) tree.
     *
     * <p>Only non-zero values are used to determine the minimum exponent; exponents
     * associated with zero values are adjusted but are not meaningful.
     *
     * @param values 53-bit values.
     * @param exponents Power of 2 exponent.
     */
    private static void scaleWeights(long[] values, int[] exponents) {
        // Find the minimum exponent and common trailing zeros.
        int minExponent = Integer.MAX_VALUE;
        for (int i = 0; i < exponents.length; i++) {
            if (values[i] != 0) {
                minExponent = Math.min(minExponent, exponents[i]);
            }
        }
        // Trailing zeros occur when the original weights have a representation with
        // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}.
        // Note: numberOfTrailingZeros(0) == 64 so zero values do not affect the minimum.
        int trailingZeros = Long.SIZE;
        for (int i = 0; i < values.length && trailingZeros != 0; i++) {
            trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i]));
        }
        // Scale by a power of 2 so the minimum exponent is zero.
        for (int i = 0; i < exponents.length; i++) {
            exponents[i] -= minExponent;
        }
        // Remove common trailing zeros.
        if (trailingZeros != 0) {
            for (int i = 0; i < values.length; i++) {
                values[i] >>>= trailingZeros;
            }
        }
    }

    /**
     * Sum the integers at the specified indices.
     * Integers are represented as {@code value * 2^exponent}.
     *
     * @param values 53-bit values.
     * @param exponents Power of 2 exponent.
     * @param indices Indices to sum.
     * @return the sum
     */
    private static BigInteger sum(long[] values, int[] exponents, int[] indices) {
        BigInteger m = BigInteger.ZERO;
        for (final int i : indices) {
            m = m.add(toBigInteger(values[i], exponents[i]));
        }
        return m;
    }

    /**
     * Convert the value and left shift offset to a BigInteger.
     * It is assumed the value is at most 53-bits. This allows optimising the left
     * shift if it is below 11 bits.
     *
     * @param value 53-bit value.
     * @param offset Left shift offset (must be non-negative).
     * @return the BigInteger
     */
    private static BigInteger toBigInteger(long value, int offset) {
        // Ignore zeros. The sum method uses indices of non-zero values.
        if (offset <= MAX_OFFSET) {
            // Assume (value << offset) <= Long.MAX_VALUE
            return BigInteger.valueOf(value << offset);
        }
        return BigInteger.valueOf(value).shiftLeft(offset);
    }

    /**
     * Creates the sampler.
     *
     * <p>It is assumed the frequencies are all positive and the sum does not
     * overflow.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param frequencies Observed frequencies of the discrete distribution.
     * @param indices Indices of non-zero frequencies.
     * @param m Sum of the frequencies.
     * @return the sampler
     */
    private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
                                                                     long[] frequencies,
                                                                     int[] indices,
                                                                     long m) {
        // ALGORITHM 5: PREPROCESS
        // a == frequencies
        // m = sum(a)
        // h = leaf node count
        // H = leaf node label (lH)

        final int n = frequencies.length;

        // k = ceil(log2(m))
        final int k = 64 - Long.numberOfLeadingZeros(m - 1);
        // r = a(n+1) = 2^k - m
        // Note: for k == 63 the subtraction wraps but yields the correct positive result.
        final long r = (1L << k) - m;

        // Note:
        // A sparse matrix can often be used for H, as most of its entries are empty.
        // This implementation uses a 1D array for efficiency at the cost of memory.
        // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of
        // observations is large enough to create k=63.
        // This could be handled using a 2D array. In practice a number of categories this
        // large is not expected and is currently not supported.
        final int[] h = new int[k];
        final int[] lH = new int[checkArraySize((n + 1L) * k)];
        Arrays.fill(lH, NO_LABEL);

        int d;
        for (int j = 0; j < k; j++) {
            final int shift = (k - 1) - j;
            final long bitMask = 1L << shift;

            d = 0;
            for (final int i : indices) {
                // bool w ← (a[i] >> (k − 1) − j)) & 1
                // h[j] = h[j] + w
                // if w then:
                if ((frequencies[i] & bitMask) != 0) {
                    h[j]++;
                    // H[d][j] = i
                    lH[d * k + j] = i;
                    d++;
                }
            }
            // process a(n+1) without extending the input frequencies array by 1
            if ((r & bitMask) != 0) {
                h[j]++;
                lH[d * k + j] = n;
            }
        }

        return new FLDRSampler(rng, n, k, h, lH);
    }

    /**
     * Creates the sampler. Frequencies are represented as a 53-bit value with a
     * left-shift offset.
     * <pre>
     * BigInteger.valueOf(value).shiftLeft(offset)
     * </pre>
     *
     * <p>It is assumed the frequencies are all positive.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param frequencies Observed frequencies of the discrete distribution.
     * @param offsets Left shift offsets (must be non-negative).
     * @param indices Indices of non-zero frequencies.
     * @param m Sum of the frequencies.
     * @return the sampler
     */
    private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng,
                                                                     long[] frequencies,
                                                                     int[] offsets,
                                                                     int[] indices,
                                                                     BigInteger m) {
        // Repeat the logic from createSampler(...) using extended arithmetic to test the bits

        // ALGORITHM 5: PREPROCESS
        // a == frequencies
        // m = sum(a)
        // h = leaf node count
        // H = leaf node label (lH)

        final int n = frequencies.length;

        // k = ceil(log2(m))
        final int k = m.subtract(BigInteger.ONE).bitLength();
        // r = a(n+1) = 2^k - m
        final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m);

        final int[] h = new int[k];
        final int[] lH = new int[checkArraySize((n + 1L) * k)];
        Arrays.fill(lH, NO_LABEL);

        int d;
        for (int j = 0; j < k; j++) {
            final int shift = (k - 1) - j;

            d = 0;
            for (final int i : indices) {
                // bool w ← (a[i] >> (k − 1) − j)) & 1
                // h[j] = h[j] + w
                // if w then:
                if (testBit(frequencies[i], offsets[i], shift)) {
                    h[j]++;
                    // H[d][j] = i
                    lH[d * k + j] = i;
                    d++;
                }
            }
            // process a(n+1) without extending the input frequencies array by 1
            if (r.testBit(shift)) {
                h[j]++;
                lH[d * k + j] = n;
            }
        }

        return new FLDRSampler(rng, n, k, h, lH);
    }

    /**
     * Test the logical bit of the shifted integer representation.
     * The value is assumed to have at most 53-bits of information. The offset
     * is assumed to be non-negative. This is functionally equivalent to:
     * <pre>
     * BigInteger.valueOf(value).shiftLeft(offset).testBit(n)
     * </pre>
     *
     * @param value 53-bit value.
     * @param offset Left shift offset.
     * @param n Index of bit to test.
     * @return true if the bit is 1
     */
    private static boolean testBit(long value, int offset, int n) {
        if (n < offset) {
            // All logical trailing bits are zero
            return false;
        }
        // Test if outside the 53-bit value (note that the implicit 1 bit
        // has been added to the 52-bit mantissas for 'normal' floating-point numbers).
        final int bit = n - offset;
        return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0;
    }

    /**
     * Check the weights have a non-zero length.
     *
     * @param weights Weights.
     * @return the length
     * @throws IllegalArgumentException if {@code weights} is null or empty.
     */
    private static int checkWeightsNonZeroLength(double[] weights) {
        if (weights == null || weights.length == 0) {
            throw new IllegalArgumentException("weights must contain at least 1 value");
        }
        return weights.length;
    }

    /**
     * Create the indices of non-zero values.
     *
     * @param values Values.
     * @return the indices (in increasing order)
     */
    private static int[] indicesOfNonZero(long[] values) {
        int n = 0;
        final int[] indices = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            if (values[i] != 0) {
                indices[n++] = i;
            }
        }
        return Arrays.copyOf(indices, n);
    }

    /**
     * Find the index of the first non-zero frequency.
     *
     * @param frequencies Frequencies.
     * @return the index
     * @throws IllegalStateException if all frequencies are zero.
     */
    static int indexOfNonZero(long[] frequencies) {
        for (int i = 0; i < frequencies.length; i++) {
            if (frequencies[i] != 0) {
                return i;
            }
        }
        throw new IllegalStateException("All frequencies are zero");
    }

    /**
     * Check the size is valid for a 1D array.
     *
     * @param size Size
     * @return the size as an {@code int}
     * @throws IllegalArgumentException if the size is too large for a 1D array.
     */
    static int checkArraySize(long size) {
        if (size > MAX_ARRAY_SIZE) {
            throw new IllegalArgumentException("Unable to allocate array of size: " + size);
        }
        return (int) size;
    }
}