001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020 021/** 022 * Sampler for a discrete distribution using an optimised look-up table. 023 * 024 * <ul> 025 * <li> 026 * The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described 027 * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete 028 * Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11. 029 * </li> 030 * </ul> 031 * 032 * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p> 033 * 034 * <p>Memory requirements depend on the maximum number of possible sample values, {@code n}, 035 * and the values for the probabilities. Storage is optimised for {@code n}. The worst case 036 * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for 037 * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for 038 * {@code n <=} 2<sup>30</sup>. 
Realistic requirements will be in the kB range.</p> 039 * 040 * <p>The sampler supports the following distributions:</p> 041 * 042 * <ul> 043 * <li>Enumerated distribution (probabilities must be provided for each sample) 044 * <li>Poisson distribution up to {@code mean = 1024} 045 * <li>Binomial distribution up to {@code trials = 65535} 046 * </ul> 047 * 048 * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol. 049 * 11, Issue 3</a> 050 * @since 1.3 051 */ 052public final class MarsagliaTsangWangDiscreteSampler { 053 /** The value 2<sup>8</sup> as an {@code int}. */ 054 private static final int INT_8 = 1 << 8; 055 /** The value 2<sup>16</sup> as an {@code int}. */ 056 private static final int INT_16 = 1 << 16; 057 /** The value 2<sup>30</sup> as an {@code int}. */ 058 private static final int INT_30 = 1 << 30; 059 /** The value 2<sup>31</sup> as a {@code double}. */ 060 private static final double DOUBLE_31 = 1L << 31; 061 062 // ========================================================================= 063 // Implementation note: 064 // 065 // This sampler uses prepared look-up tables that are searched using a single 066 // random int variate. The look-up tables contain the sample value. The tables 067 // are constructed using probabilities that sum to 2^30. The original paper 068 // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables 069 // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6) 070 // is supported using 5 look-up tables. 071 // 072 // The implementations use 8, 16 or 32 bit storage tables to support different 073 // distribution sizes with optimal storage. Separate class implementations of 074 // the same algorithm allow array storage to be accessed directly from 1D tables. 075 // This provides a performance gain over using: abstracted storage accessed via 076 // an interface; or a single 2D table.
077 // 078 // To allow the optimal implementation to be chosen the sampler is created 079 // using factory methods. The sampler supports any probability distribution 080 // when provided via an array of probabilities and the Poisson and Binomial 081 // distributions for a restricted set of parameters. The restrictions are 082 // imposed by the requirement to compute the entire probability distribution 083 // from the controlling parameter(s) using a recursive method. Factory 084 // constructors return a SharedStateDiscreteSampler instance. Each distribution 085 // type is contained in an inner class. 086 // ========================================================================= 087 088 /** 089 * The base class for Marsaglia-Tsang-Wang samplers. 090 */ 091 private abstract static class AbstractMarsagliaTsangWangDiscreteSampler 092 implements SharedStateDiscreteSampler { 093 /** Underlying source of randomness. */ 094 protected final UniformRandomProvider rng; 095 096 /** The name of the distribution. */ 097 private final String distributionName; 098 099 /** 100 * @param rng Generator of uniformly distributed random numbers. 101 * @param distributionName Distribution name. 102 */ 103 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, 104 String distributionName) { 105 this.rng = rng; 106 this.distributionName = distributionName; 107 } 108 109 /** 110 * @param rng Generator of uniformly distributed random numbers. 111 * @param source Source to copy. 
112 */ 113 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, 114 AbstractMarsagliaTsangWangDiscreteSampler source) { 115 this.rng = rng; 116 this.distributionName = source.distributionName; 117 } 118 119 /** {@inheritDoc} */ 120 @Override 121 public String toString() { 122 return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]"; 123 } 124 } 125 126 /** 127 * An implementation for the sample algorithm based on the decomposition of the 128 * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage. 129 */ 130 private static class MarsagliaTsangWangBase64Int8DiscreteSampler 131 extends AbstractMarsagliaTsangWangDiscreteSampler { 132 /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */ 133 private static final int MASK = 0xff; 134 135 /** Limit for look-up table 1. */ 136 private final int t1; 137 /** Limit for look-up table 2. */ 138 private final int t2; 139 /** Limit for look-up table 3. */ 140 private final int t3; 141 /** Limit for look-up table 4. */ 142 private final int t4; 143 144 /** Look-up table table1. */ 145 private final byte[] table1; 146 /** Look-up table table2. */ 147 private final byte[] table2; 148 /** Look-up table table3. */ 149 private final byte[] table3; 150 /** Look-up table table4. */ 151 private final byte[] table4; 152 /** Look-up table table5. */ 153 private final byte[] table5; 154 155 /** 156 * @param rng Generator of uniformly distributed random numbers. 157 * @param distributionName Distribution name. 158 * @param prob The probabilities. 159 * @param offset The offset (must be positive). 
160 */ 161 MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng, 162 String distributionName, 163 int[] prob, 164 int offset) { 165 super(rng, distributionName); 166 167 // Get table sizes for each base-64 digit 168 int n1 = 0; 169 int n2 = 0; 170 int n3 = 0; 171 int n4 = 0; 172 int n5 = 0; 173 for (final int m : prob) { 174 n1 += getBase64Digit(m, 1); 175 n2 += getBase64Digit(m, 2); 176 n3 += getBase64Digit(m, 3); 177 n4 += getBase64Digit(m, 4); 178 n5 += getBase64Digit(m, 5); 179 } 180 181 table1 = new byte[n1]; 182 table2 = new byte[n2]; 183 table3 = new byte[n3]; 184 table4 = new byte[n4]; 185 table5 = new byte[n5]; 186 187 // Compute offsets 188 t1 = n1 << 24; 189 t2 = t1 + (n2 << 18); 190 t3 = t2 + (n3 << 12); 191 t4 = t3 + (n4 << 6); 192 n1 = n2 = n3 = n4 = n5 = 0; 193 194 // Fill tables 195 for (int i = 0; i < prob.length; i++) { 196 final int m = prob[i]; 197 // Primitive type conversion will extract lower 8 bits 198 final byte k = (byte) (i + offset); 199 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k); 200 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k); 201 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k); 202 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k); 203 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k); 204 } 205 } 206 207 /** 208 * @param rng Generator of uniformly distributed random numbers. 209 * @param source Source to copy. 210 */ 211 private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng, 212 MarsagliaTsangWangBase64Int8DiscreteSampler source) { 213 super(rng, source); 214 t1 = source.t1; 215 t2 = source.t2; 216 t3 = source.t3; 217 t4 = source.t4; 218 table1 = source.table1; 219 table2 = source.table2; 220 table3 = source.table3; 221 table4 = source.table4; 222 table5 = source.table5; 223 } 224 225 /** 226 * Fill the table with the value. 227 * 228 * @param table Table. 
229 * @param from Lower bound index (inclusive) 230 * @param to Upper bound index (exclusive) 231 * @param value Value. 232 * @return the upper bound index 233 */ 234 private static int fill(byte[] table, int from, int to, byte value) { 235 for (int i = from; i < to; i++) { 236 table[i] = value; 237 } 238 return to; 239 } 240 241 @Override 242 public int sample() { 243 final int j = rng.nextInt() >>> 2; 244 if (j < t1) { 245 return table1[j >>> 24] & MASK; 246 } 247 if (j < t2) { 248 return table2[(j - t1) >>> 18] & MASK; 249 } 250 if (j < t3) { 251 return table3[(j - t2) >>> 12] & MASK; 252 } 253 if (j < t4) { 254 return table4[(j - t3) >>> 6] & MASK; 255 } 256 // Note the tables are filled on the assumption that the sum of the probabilities 257 // is >=2^30. If this is not true then the final table table5 will be smaller by the 258 // difference. So the tables *must* be constructed correctly. 259 return table5[j - t4] & MASK; 260 } 261 262 @Override 263 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 264 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this); 265 } 266 } 267 268 /** 269 * An implementation for the sample algorithm based on the decomposition of the 270 * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage. 271 */ 272 private static class MarsagliaTsangWangBase64Int16DiscreteSampler 273 extends AbstractMarsagliaTsangWangDiscreteSampler { 274 /** The mask to convert a {@code short} to an unsigned 16-bit integer. */ 275 private static final int MASK = 0xffff; 276 277 /** Limit for look-up table 1. */ 278 private final int t1; 279 /** Limit for look-up table 2. */ 280 private final int t2; 281 /** Limit for look-up table 3. */ 282 private final int t3; 283 /** Limit for look-up table 4. */ 284 private final int t4; 285 286 /** Look-up table table1. */ 287 private final short[] table1; 288 /** Look-up table table2.
*/ 289 private final short[] table2; 290 /** Look-up table table3. */ 291 private final short[] table3; 292 /** Look-up table table4. */ 293 private final short[] table4; 294 /** Look-up table table5. */ 295 private final short[] table5; 296 297 /** 298 * @param rng Generator of uniformly distributed random numbers. 299 * @param distributionName Distribution name. 300 * @param prob The probabilities. 301 * @param offset The offset (must be positive). 302 */ 303 MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng, 304 String distributionName, 305 int[] prob, 306 int offset) { 307 super(rng, distributionName); 308 309 // Get table sizes for each base-64 digit 310 int n1 = 0; 311 int n2 = 0; 312 int n3 = 0; 313 int n4 = 0; 314 int n5 = 0; 315 for (final int m : prob) { 316 n1 += getBase64Digit(m, 1); 317 n2 += getBase64Digit(m, 2); 318 n3 += getBase64Digit(m, 3); 319 n4 += getBase64Digit(m, 4); 320 n5 += getBase64Digit(m, 5); 321 } 322 323 table1 = new short[n1]; 324 table2 = new short[n2]; 325 table3 = new short[n3]; 326 table4 = new short[n4]; 327 table5 = new short[n5]; 328 329 // Compute offsets 330 t1 = n1 << 24; 331 t2 = t1 + (n2 << 18); 332 t3 = t2 + (n3 << 12); 333 t4 = t3 + (n4 << 6); 334 n1 = n2 = n3 = n4 = n5 = 0; 335 336 // Fill tables 337 for (int i = 0; i < prob.length; i++) { 338 final int m = prob[i]; 339 // Primitive type conversion will extract lower 16 bits 340 final short k = (short) (i + offset); 341 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k); 342 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k); 343 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k); 344 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k); 345 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k); 346 } 347 } 348 349 /** 350 * @param rng Generator of uniformly distributed random numbers. 351 * @param source Source to copy. 
352 */ 353 private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng, 354 MarsagliaTsangWangBase64Int16DiscreteSampler source) { 355 super(rng, source); 356 t1 = source.t1; 357 t2 = source.t2; 358 t3 = source.t3; 359 t4 = source.t4; 360 table1 = source.table1; 361 table2 = source.table2; 362 table3 = source.table3; 363 table4 = source.table4; 364 table5 = source.table5; 365 } 366 367 /** 368 * Fill the table with the value. 369 * 370 * @param table Table. 371 * @param from Lower bound index (inclusive) 372 * @param to Upper bound index (exclusive) 373 * @param value Value. 374 * @return the upper bound index 375 */ 376 private static int fill(short[] table, int from, int to, short value) { 377 for (int i = from; i < to; i++) { 378 table[i] = value; 379 } 380 return to; 381 } 382 383 @Override 384 public int sample() { 385 final int j = rng.nextInt() >>> 2; 386 if (j < t1) { 387 return table1[j >>> 24] & MASK; 388 } 389 if (j < t2) { 390 return table2[(j - t1) >>> 18] & MASK; 391 } 392 if (j < t3) { 393 return table3[(j - t2) >>> 12] & MASK; 394 } 395 if (j < t4) { 396 return table4[(j - t3) >>> 6] & MASK; 397 } 398 // Note the tables are filled on the assumption that the sum of the probabilities 399 // is >=2^30. If this is not true then the final table table5 will be smaller by the 400 // difference. So the tables *must* be constructed correctly. 401 return table5[j - t4] & MASK; 402 } 403 404 @Override 405 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 406 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this); 407 } 408 } 409 410 /** 411 * An implementation for the sample algorithm based on the decomposition of the 412 * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage. 413 */ 414 private static class MarsagliaTsangWangBase64Int32DiscreteSampler 415 extends AbstractMarsagliaTsangWangDiscreteSampler { 416 /** Limit for look-up table 1.
*/ 417 private final int t1; 418 /** Limit for look-up table 2. */ 419 private final int t2; 420 /** Limit for look-up table 3. */ 421 private final int t3; 422 /** Limit for look-up table 4. */ 423 private final int t4; 424 425 /** Look-up table table1. */ 426 private final int[] table1; 427 /** Look-up table table2. */ 428 private final int[] table2; 429 /** Look-up table table3. */ 430 private final int[] table3; 431 /** Look-up table table4. */ 432 private final int[] table4; 433 /** Look-up table table5. */ 434 private final int[] table5; 435 436 /** 437 * @param rng Generator of uniformly distributed random numbers. 438 * @param distributionName Distribution name. 439 * @param prob The probabilities. 440 * @param offset The offset (must be positive). 441 */ 442 MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng, 443 String distributionName, 444 int[] prob, 445 int offset) { 446 super(rng, distributionName); 447 448 // Get table sizes for each base-64 digit 449 int n1 = 0; 450 int n2 = 0; 451 int n3 = 0; 452 int n4 = 0; 453 int n5 = 0; 454 for (final int m : prob) { 455 n1 += getBase64Digit(m, 1); 456 n2 += getBase64Digit(m, 2); 457 n3 += getBase64Digit(m, 3); 458 n4 += getBase64Digit(m, 4); 459 n5 += getBase64Digit(m, 5); 460 } 461 462 table1 = new int[n1]; 463 table2 = new int[n2]; 464 table3 = new int[n3]; 465 table4 = new int[n4]; 466 table5 = new int[n5]; 467 468 // Compute offsets 469 t1 = n1 << 24; 470 t2 = t1 + (n2 << 18); 471 t3 = t2 + (n3 << 12); 472 t4 = t3 + (n4 << 6); 473 n1 = n2 = n3 = n4 = n5 = 0; 474 475 // Fill tables 476 for (int i = 0; i < prob.length; i++) { 477 final int m = prob[i]; 478 final int k = i + offset; 479 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k); 480 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k); 481 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k); 482 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k); 483 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k); 484 } 485 } 486 487 /** 488 
* @param rng Generator of uniformly distributed random numbers. 489 * @param source Source to copy. 490 */ 491 private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng, 492 MarsagliaTsangWangBase64Int32DiscreteSampler source) { 493 super(rng, source); 494 t1 = source.t1; 495 t2 = source.t2; 496 t3 = source.t3; 497 t4 = source.t4; 498 table1 = source.table1; 499 table2 = source.table2; 500 table3 = source.table3; 501 table4 = source.table4; 502 table5 = source.table5; 503 } 504 505 /** 506 * Fill the table with the value. 507 * 508 * @param table Table. 509 * @param from Lower bound index (inclusive) 510 * @param to Upper bound index (exclusive) 511 * @param value Value. 512 * @return the upper bound index 513 */ 514 private static int fill(int[] table, int from, int to, int value) { 515 for (int i = from; i < to; i++) { 516 table[i] = value; 517 } 518 return to; 519 } 520 521 @Override 522 public int sample() { 523 final int j = rng.nextInt() >>> 2; 524 if (j < t1) { 525 return table1[j >>> 24]; 526 } 527 if (j < t2) { 528 return table2[(j - t1) >>> 18]; 529 } 530 if (j < t3) { 531 return table3[(j - t2) >>> 12]; 532 } 533 if (j < t4) { 534 return table4[(j - t3) >>> 6]; 535 } 536 // Note the tables are filled on the assumption that the sum of the probabilities 537 // is >=2^30. If this is not true then the final table table5 will be smaller by the 538 // difference. So the tables *must* be constructed correctly. 539 return table5[j - t4]; 540 } 541 542 @Override 543 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 544 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this); 545 } 546 } 547 548 549 550 /** Class contains only static methods. */ 551 private MarsagliaTsangWangDiscreteSampler() {} 552 553 /** 554 * Gets the k<sup>th</sup> base 64 digit of {@code m}. 555 * 556 * @param m the value m. 557 * @param k the digit.
558 * @return the base 64 digit 559 */ 560 private static int getBase64Digit(int m, int k) { 561 return (m >>> (30 - 6 * k)) & 63; 562 } 563 564 /** 565 * Convert the probability to an integer in the range [0,2^30]. This is the numerator of 566 * a fraction with assumed denominator 2<sup>30</sup>. 567 * 568 * @param p Probability. 569 * @return the fraction numerator 570 */ 571 private static int toUnsignedInt30(double p) { 572 return (int) (p * INT_30 + 0.5); 573 } 574 575 /** 576 * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is 577 * {@code i + offset}. 578 * 579 * <p>The sum of the probabilities must be >= 2<sup>30</sup>. Only the 580 * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p> 581 * 582 * @param rng Generator of uniformly distributed random numbers. 583 * @param distributionName Distribution name. 584 * @param prob The probabilities. 585 * @param offset The offset (must be positive). 586 * @return Sampler. 587 */ 588 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng, 589 String distributionName, 590 int[] prob, 591 int offset) { 592 // Note: No argument checks for private method. 
593 594 // Choose implementation based on the maximum index 595 final int maxIndex = prob.length + offset - 1; 596 if (maxIndex < INT_8) { 597 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset); 598 } 599 if (maxIndex < INT_16) { 600 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset); 601 } 602 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset); 603 } 604 605 // ========================================================================= 606 // The following public classes provide factory methods to construct a sampler for: 607 // - Enumerated probability distribution (from provided double[] probabilities) 608 // - Poisson distribution for mean <= 1024 609 // - Binomial distribution for trials <= 65535 610 // ========================================================================= 611 612 /** 613 * Create a sampler for an enumerated distribution of {@code n} values each with an 614 * associated probability. 615 * The samples corresponding to each probability are assumed to be a natural sequence 616 * starting at zero. 617 */ 618 public static final class Enumerated { 619 /** The name of the enumerated probability distribution. */ 620 private static final String ENUMERATED_NAME = "Enumerated"; 621 622 /** Class contains only static methods. */ 623 private Enumerated() {} 624 625 /** 626 * Creates a sampler for an enumerated distribution of {@code n} values each with an 627 * associated probability. 628 * 629 * <p>The probabilities will be normalised using their sum. The only requirement 630 * is the sum is positive.</p> 631 * 632 * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that 633 * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during 634 * the normalisation conversion. 
Consequently any probability less than 2<sup>-31</sup> 635 * will not be observed in samples.</p> 636 * 637 * @param rng Generator of uniformly distributed random numbers. 638 * @param probabilities The list of probabilities. 639 * @return Sampler. 640 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 641 * probability is negative, infinite or {@code NaN}, or the sum of all 642 * probabilities is not strictly positive. 643 */ 644 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 645 double[] probabilities) { 646 return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0); 647 } 648 649 /** 650 * Normalise the probabilities to integers that sum to 2<sup>30</sup>. 651 * 652 * @param probabilities The list of probabilities. 653 * @return the normalised probabilities. 654 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 655 * probability is negative, infinite or {@code NaN}, or the sum of all 656 * probabilities is not strictly positive. 657 */ 658 private static int[] normaliseProbabilities(double[] probabilities) { 659 final double sumProb = InternalUtils.validateProbabilities(probabilities); 660 661 // Compute the normalisation: 2^30 / sum 662 final double normalisation = INT_30 / sumProb; 663 final int[] prob = new int[probabilities.length]; 664 int sum = 0; 665 int max = 0; 666 int mode = 0; 667 for (int i = 0; i < prob.length; i++) { 668 // Add 0.5 for rounding 669 final int p = (int) (probabilities[i] * normalisation + 0.5); 670 sum += p; 671 // Find the mode (maximum probability) 672 if (max < p) { 673 max = p; 674 mode = i; 675 } 676 prob[i] = p; 677 } 678 679 // The sum must be >= 2^30. 680 // Here just compensate the difference onto the highest probability. 681 prob[mode] += INT_30 - sum; 682 683 return prob; 684 } 685 } 686 687 /** 688 * Create a sampler for the Poisson distribution. 
689 */ 690 public static final class Poisson { 691 /** The name of the Poisson distribution. */ 692 private static final String POISSON_NAME = "Poisson"; 693 694 /** 695 * Upper bound on the mean for the Poisson distribution. 696 * 697 * <p>The original source code provided in Marsaglia, et al (2004) has no explicit 698 * limit but the code fails at mean >= 1941 as the transform to compute p(x=mode) 699 * produces infinity. Use a conservative limit of 1024.</p> 700 */ 701 702 private static final double MAX_MEAN = 1024; 703 /** 704 * The threshold for the mean of the Poisson distribution to switch the method used 705 * to compute the probabilities. This is taken from the example software provided by 706 * Marsaglia, et al (2004). 707 */ 708 private static final double MEAN_THRESHOLD = 21.4; 709 710 /** Class contains only static methods. */ 711 private Poisson() {} 712 713 /** 714 * Creates a sampler for the Poisson distribution. 715 * 716 * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p> 717 * 718 * <p>Storage requirements depend on the tabulated probability values. Example storage 719 * requirements are listed below.</p> 720 * 721 * <pre> 722 * mean table size kB 723 * 0.25 882 0.88 724 * 0.5 1135 1.14 725 * 1 1200 1.20 726 * 2 1451 1.45 727 * 4 1955 1.96 728 * 8 2961 2.96 729 * 16 4410 4.41 730 * 32 6115 6.11 731 * 64 8499 8.50 732 * 128 11528 11.53 733 * 256 15935 31.87 734 * 512 20912 41.82 735 * 1024 30614 61.23 736 * </pre> 737 * 738 * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p> 739 * 740 * @param rng Generator of uniformly distributed random numbers. 741 * @param mean Mean. 742 * @return Sampler. 743 * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}. 
744 */ 745 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 746 double mean) { 747 validatePoissonDistributionParameters(mean); 748 749 // Create the distribution either from X=0 or from X=mode when the mean is high. 750 return mean < MEAN_THRESHOLD ? 751 createPoissonDistributionFromX0(rng, mean) : 752 createPoissonDistributionFromXMode(rng, mean); 753 } 754 755 /** 756 * Validate the Poisson distribution parameters. The positivity check is written as {@code !(mean > 0)} so that a {@code NaN} mean is rejected instead of producing a degenerate sampler that always returns 0. 757 * 758 * @param mean Mean. 759 * @throws IllegalArgumentException if {@code mean <= 0}, {@code mean} is {@code NaN}, or {@code mean > 1024}. 760 */ 761 private static void validatePoissonDistributionParameters(double mean) { 762 if (!(mean > 0)) { 763 throw new IllegalArgumentException("mean is not strictly positive: " + mean); 764 } 765 if (mean > MAX_MEAN) { 766 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); 767 } 768 } 769 770 /** 771 * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}. 772 * 773 * @param rng Generator of uniformly distributed random numbers. 774 * @param mean Mean. 775 * @return Sampler. 776 */ 777 private static SharedStateDiscreteSampler createPoissonDistributionFromX0( 778 UniformRandomProvider rng, double mean) { 779 final double p0 = Math.exp(-mean); 780 781 // Recursive update of Poisson probability until the value is too small 782 // p(x + 1) = p(x) * mean / (x + 1) 783 double p = p0; 784 int i = 1; 785 while (p * DOUBLE_31 >= 1) { 786 p *= mean / i++; 787 } 788 789 // Probabilities are 30-bit integers, assumed denominator 2^30 790 final int size = i - 1; 791 final int[] prob = new int[size]; 792 793 p = p0; 794 prob[0] = toUnsignedInt30(p); 795 // The sum must exceed 2^30. In edge cases this is false due to round-off. 796 int sum = prob[0]; 797 for (i = 1; i < prob.length; i++) { 798 p *= mean / i; 799 prob[i] = toUnsignedInt30(p); 800 sum += prob[i]; 801 } 802 803 // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
804 prob[(int) mean] += Math.max(0, INT_30 - sum); 805 806 // Note: offset = 0 807 return createSampler(rng, POISSON_NAME, prob, 0); 808 } 809 810 /** 811 * Creates the Poisson distribution by computing probabilities recursively upward and downward 812 * from {@code X=mode}, the location of the largest p-value. 813 * 814 * @param rng Generator of uniformly distributed random numbers. 815 * @param mean Mean. 816 * @return Sampler. 817 */ 818 private static SharedStateDiscreteSampler createPoissonDistributionFromXMode( 819 UniformRandomProvider rng, double mean) { 820 // If mean >= 21.4, generate from largest p-value up, then largest down. 821 // The largest p-value will be at the mode (floor(mean)). 822 823 // Find p(x=mode) 824 final int mode = (int) mean; 825 // This transform is stable until mean >= 1941 where p will result in Infinity 826 // before the divisor i is large enough to start reducing the product (i.e. i > c). 827 final double c = mean * Math.exp(-mean / mode); 828 double p = 1.0; 829 for (int i = 1; i <= mode; i++) { 830 p *= c / i; 831 } 832 final double pMode = p; 833 834 // Find the upper limit using recursive computation of the p-value. 835 // Note this will exit when i overflows to negative so no check on the range 836 int i = mode + 1; 837 while (p * DOUBLE_31 >= 1) { 838 p *= mean / i++; 839 } 840 final int last = i - 2; 841 842 // Find the lower limit using recursive computation of the p-value. 843 p = pMode; 844 int j = -1; 845 for (i = mode - 1; i >= 0; i--) { 846 p *= (i + 1) / mean; 847 if (p * DOUBLE_31 < 1) { 848 j = i; 849 break; 850 } 851 } 852 853 // Probabilities are 30-bit integers, assumed denominator 2^30. 854 // This is the minimum sample value: prob[x - offset] = p(x) 855 final int offset = j + 1; 856 final int size = last - offset + 1; 857 final int[] prob = new int[size]; 858 859 p = pMode; 860 prob[mode - offset] = toUnsignedInt30(p); 861 // The sum must exceed 2^30. In edge cases this is false due to round-off.
862 int sum = prob[mode - offset]; 863 // From mode to upper limit 864 for (i = mode + 1; i <= last; i++) { 865 p *= mean / i; 866 prob[i - offset] = toUnsignedInt30(p); 867 sum += prob[i - offset]; 868 } 869 // From mode to lower limit 870 p = pMode; 871 for (i = mode - 1; i >= offset; i--) { 872 p *= (i + 1) / mean; 873 prob[i - offset] = toUnsignedInt30(p); 874 sum += prob[i - offset]; 875 } 876 877 // If the sum is < 2^30 add the remaining sum to the mode. 878 // If above 2^30 then the effect is truncation of the long tail of the distribution. 879 prob[mode - offset] += Math.max(0, INT_30 - sum); 880 881 return createSampler(rng, POISSON_NAME, prob, offset); 882 } 883 } 884 885 /** 886 * Create a sampler for the Binomial distribution. 887 */ 888 public static final class Binomial { 889 /** The name of the Binomial distribution. */ 890 private static final String BINOMIAL_NAME = "Binomial"; 891 892 /** 893 * Return a fixed result for the Binomial distribution. This is a special class to handle 894 * an edge case of probability of success equal to 0 or 1. 895 */ 896 private static class MarsagliaTsangWangFixedResultBinomialSampler 897 extends AbstractMarsagliaTsangWangDiscreteSampler { 898 /** The result. */ 899 private final int result; 900 901 /** 902 * @param result Result. 903 */ 904 MarsagliaTsangWangFixedResultBinomialSampler(int result) { 905 super(null, BINOMIAL_NAME); 906 this.result = result; 907 } 908 909 @Override 910 public int sample() { 911 return result; 912 } 913 914 @Override 915 public String toString() { 916 return BINOMIAL_NAME + " deviate"; 917 } 918 919 @Override 920 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 921 // No shared state 922 return this; 923 } 924 } 925 926 /** 927 * Return an inversion result for the Binomial distribution. 
This assumes the 928 * following: 929 * 930 * <pre> 931 * Binomial(n, p) = n - Binomial(n, 1 - p) 932 * </pre> 933 */ 934 private static class MarsagliaTsangWangInversionBinomialSampler 935 extends AbstractMarsagliaTsangWangDiscreteSampler { 936 /** The number of trials. */ 937 private final int trials; 938 /** The Binomial distribution sampler. */ 939 private final SharedStateDiscreteSampler sampler; 940 941 /** 942 * @param trials Number of trials. 943 * @param sampler Binomial distribution sampler. 944 */ 945 MarsagliaTsangWangInversionBinomialSampler(int trials, 946 SharedStateDiscreteSampler sampler) { 947 super(null, BINOMIAL_NAME); 948 this.trials = trials; 949 this.sampler = sampler; 950 } 951 952 @Override 953 public int sample() { 954 return trials - sampler.sample(); 955 } 956 957 @Override 958 public String toString() { 959 return sampler.toString(); 960 } 961 962 @Override 963 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 964 return new MarsagliaTsangWangInversionBinomialSampler(this.trials, 965 this.sampler.withUniformRandomProvider(rng)); 966 } 967 } 968 969 /** Class contains only static methods. */ 970 private Binomial() {} 971 972 /** 973 * Creates a sampler for the Binomial distribution. 974 * 975 * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p> 976 * 977 * <p>Storage requirements depend on the tabulated probability values. Example storage 978 * requirements are listed below (in kB).</p> 979 * 980 * <pre> 981 * p 982 * trials 0.5 0.1 0.01 0.001 983 * 4 0.06 0.63 0.44 0.44 984 * 16 0.69 1.14 0.76 0.44 985 * 64 4.73 2.40 1.14 0.51 986 * 256 8.63 5.17 1.89 0.82 987 * 1024 31.12 9.45 3.34 0.89 988 * </pre> 989 * 990 * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed. 991 * This will fail when {@code (1 - q)^trials == 0}, where {@code q = min(p, 1 - p)}, which requires {@code trials} 992 * to be large and {@code p} to be sufficiently far from both 0 and 1.
In this case an exception is raised.</p> 993 * 994 * @param rng Generator of uniformly distributed random numbers. 995 * @param trials Number of trials. 996 * @param probabilityOfSuccess Probability of success (p). 997 * @return Sampler. 998 * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16}, 999 * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot 1000 * be computed. 1001 */ 1002 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 1003 int trials, 1004 double probabilityOfSuccess) { 1005 validateBinomialDistributionParameters(trials, probabilityOfSuccess); 1006 1007 // Handle edge cases 1008 if (probabilityOfSuccess == 0) { 1009 return new MarsagliaTsangWangFixedResultBinomialSampler(0); 1010 } 1011 if (probabilityOfSuccess == 1) { 1012 return new MarsagliaTsangWangFixedResultBinomialSampler(trials); 1013 } 1014 1015 // Check the supported size. 1016 if (trials >= INT_16) { 1017 throw new IllegalArgumentException("Unsupported number of trials: " + trials); 1018 } 1019 1020 return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess); 1021 } 1022 1023 /** 1024 * Validate the Binomial distribution parameters. 1025 * 1026 * @param trials Number of trials. 1027 * @param probabilityOfSuccess Probability of success (p). 1028 * @throws IllegalArgumentException if {@code trials < 0} or 1029 * {@code p} is not in the range {@code [0-1]} 1030 */ 1031 private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) { 1032 if (trials < 0) { 1033 throw new IllegalArgumentException("Trials is not positive: " + trials); 1034 } 1035 if (probabilityOfSuccess < 0 || probabilityOfSuccess > 1) { 1036 throw new IllegalArgumentException("Probability is not in range [0,1]: " + probabilityOfSuccess); 1037 } 1038 } 1039 1040 /** 1041 * Creates the Binomial distribution sampler. 1042 * 1043 * <p>This assumes the parameters for the distribution are valid. 
The method 1044 * will only fail if the initial probability for {@code X=0} is zero.</p> 1045 * 1046 * @param rng Generator of uniformly distributed random numbers. 1047 * @param trials Number of trials. 1048 * @param probabilityOfSuccess Probability of success (p). 1049 * @return Sampler. 1050 * @throws IllegalArgumentException if the probability distribution cannot be 1051 * computed. 1052 */ 1053 private static SharedStateDiscreteSampler createBinomialDistributionSampler( 1054 UniformRandomProvider rng, int trials, double probabilityOfSuccess) { 1055 1056 // The maximum supported value for Math.exp is approximately -744. 1057 // This occurs when trials is large and p is close to 1. 1058 // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j 1059 final boolean useInversion = probabilityOfSuccess > 0.5; 1060 final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess; 1061 1062 // Check if the distribution can be computed 1063 final double p0 = Math.exp(trials * Math.log(1 - p)); 1064 if (p0 < Double.MIN_VALUE) { 1065 throw new IllegalArgumentException("Unable to compute distribution"); 1066 } 1067 1068 // First find size of probability array 1069 double t = p0; 1070 final double h = p / (1 - p); 1071 // Find first probability above the threshold of 2^-31 1072 int begin = 0; 1073 if (t * DOUBLE_31 < 1) { 1074 // Somewhere after p(0) 1075 // Note: 1076 // If this loop is entered p(0) is < 2^-31. 1077 // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either 1078 // p=0.5 or trials=2^16-1 and does not fail to find the beginning. 
1079 for (int i = 1; i <= trials; i++) { 1080 t *= (trials + 1 - i) * h / i; 1081 if (t * DOUBLE_31 >= 1) { 1082 begin = i; 1083 break; 1084 } 1085 } 1086 } 1087 // Find last probability 1088 int end = trials; 1089 for (int i = begin + 1; i <= trials; i++) { 1090 t *= (trials + 1 - i) * h / i; 1091 if (t * DOUBLE_31 < 1) { 1092 end = i - 1; 1093 break; 1094 } 1095 } 1096 1097 return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion, 1098 p0, begin, end); 1099 } 1100 1101 /** 1102 * Creates the Binomial distribution sampler using only the probability values for {@code X} 1103 * between the begin and the end (inclusive). 1104 * 1105 * @param rng Generator of uniformly distributed random numbers. 1106 * @param trials Number of trials. 1107 * @param p Probability of success (p). 1108 * @param useInversion Set to {@code true} if the probability was inverted. 1109 * @param p0 Probability at {@code X=0} 1110 * @param begin Begin value {@code X} for the distribution. 1111 * @param end End value {@code X} for the distribution. 1112 * @return Sampler. 1113 */ 1114 private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange( 1115 UniformRandomProvider rng, int trials, double p, 1116 boolean useInversion, double p0, int begin, int end) { 1117 1118 // Assign probability values as 30-bit integers 1119 final int size = end - begin + 1; 1120 final int[] prob = new int[size]; 1121 double t = p0; 1122 final double h = p / (1 - p); 1123 for (int i = 1; i <= begin; i++) { 1124 t *= (trials + 1 - i) * h / i; 1125 } 1126 int sum = toUnsignedInt30(t); 1127 prob[0] = sum; 1128 for (int i = begin + 1; i <= end; i++) { 1129 t *= (trials + 1 - i) * h / i; 1130 prob[i - begin] = toUnsignedInt30(t); 1131 sum += prob[i - begin]; 1132 } 1133 1134 // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))). 1135 // If above 2^30 then the effect is truncation of the long tail of the distribution. 
1136 final int mode = (int) ((trials + 1) * p) - begin; 1137 prob[mode] += Math.max(0, INT_30 - sum); 1138 1139 final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin); 1140 1141 // Check if an inversion was made 1142 return useInversion ? 1143 new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) : 1144 sampler; 1145 } 1146 } 1147}