001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020 021/** 022 * Sampler for a discrete distribution using an optimised look-up table. 023 * 024 * <ul> 025 * <li> 026 * The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described 027 * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete 028 * Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11. 029 * </li> 030 * </ul> 031 * 032 * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p> 033 * 034 * <p>Memory requirements depend on the maximum number of possible sample values, {@code n}, 035 * and the values for the probabilities. Storage is optimised for {@code n}. The worst case 036 * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for 037 * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for 038 * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p> 039 * 040 * <p>The sampler supports the following distributions:</p> 041 * 042 * <ul> 043 * <li>Enumerated distribution (probabilities must be provided for each sample) 044 * <li>Poisson distribution up to {@code mean = 1024} 045 * <li>Binomial distribution up to {@code trials = 65535} 046 * </ul> 047 * 048 * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Margsglia, et al (2004) JSS Vol. 049 * 11, Issue 3</a> 050 * @since 1.3 051 */ 052public final class MarsagliaTsangWangDiscreteSampler { 053 /** The value 2<sup>8</sup> as an {@code int}. */ 054 private static final int INT_8 = 1 << 8; 055 /** The value 2<sup>16</sup> as an {@code int}. */ 056 private static final int INT_16 = 1 << 16; 057 /** The value 2<sup>30</sup> as an {@code int}. */ 058 private static final int INT_30 = 1 << 30; 059 /** The value 2<sup>31</sup> as a {@code double}. */ 060 private static final double DOUBLE_31 = 1L << 31; 061 062 // ========================================================================= 063 // Implementation note: 064 // 065 // This sampler uses prepared look-up tables that are searched using a single 066 // random int variate. The look-up tables contain the sample value. The tables 067 // are constructed using probabilities that sum to 2^30. The original paper 068 // by Marsaglia, et al (2004) describes the use of 5, 3, or 2 look-up tables 069 // indexed using digits of base 2^6, 2^10 or 2^15. Currently only base 64 (2^6) 070 // is supported using 5 look-up tables. 071 // 072 // The implementations use 8, 16 or 32 bit storage tables to support different 073 // distribution sizes with optimal storage. Separate class implementations of 074 // the same algorithm allow array storage to be accessed directly from 1D tables. 075 // This provides a performance gain over using: abstracted storage accessed via 076 // an interface; or a single 2D table. 077 // 078 // To allow the optimal implementation to be chosen the sampler is created 079 // using factory methods. The sampler supports any probability distribution 080 // when provided via an array of probabilities and the Poisson and Binomial 081 // distributions for a restricted set of parameters. The restrictions are 082 // imposed by the requirement to compute the entire probability distribution 083 // from the controlling parameter(s) using a recursive method. Factory 084 // constructors return a SharedStateDiscreteSampler instance. Each distribution 085 // type is contained in an inner class. 086 // ========================================================================= 087 088 /** 089 * The base class for Marsaglia-Tsang-Wang samplers. 090 */ 091 private abstract static class AbstractMarsagliaTsangWangDiscreteSampler 092 implements SharedStateDiscreteSampler { 093 /** Underlying source of randomness. */ 094 protected final UniformRandomProvider rng; 095 096 /** The name of the distribution. */ 097 private final String distributionName; 098 099 /** 100 * @param rng Generator of uniformly distributed random numbers. 101 * @param distributionName Distribution name. 102 */ 103 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, 104 String distributionName) { 105 this.rng = rng; 106 this.distributionName = distributionName; 107 } 108 109 /** 110 * @param rng Generator of uniformly distributed random numbers. 111 * @param source Source to copy. 112 */ 113 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, 114 AbstractMarsagliaTsangWangDiscreteSampler source) { 115 this.rng = rng; 116 this.distributionName = source.distributionName; 117 } 118 119 /** {@inheritDoc} */ 120 @Override 121 public String toString() { 122 return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]"; 123 } 124 } 125 126 /** 127 * An implementation for the sample algorithm based on the decomposition of the 128 * index in the range {@code [0,2^30)} into 5 base-64 digits with 8-bit backing storage. 129 */ 130 private static final class MarsagliaTsangWangBase64Int8DiscreteSampler 131 extends AbstractMarsagliaTsangWangDiscreteSampler { 132 /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */ 133 private static final int MASK = 0xff; 134 135 /** Limit for look-up table 1. */ 136 private final int t1; 137 /** Limit for look-up table 2. */ 138 private final int t2; 139 /** Limit for look-up table 3. */ 140 private final int t3; 141 /** Limit for look-up table 4. */ 142 private final int t4; 143 144 /** Look-up table table1. */ 145 private final byte[] table1; 146 /** Look-up table table2. */ 147 private final byte[] table2; 148 /** Look-up table table3. */ 149 private final byte[] table3; 150 /** Look-up table table4. */ 151 private final byte[] table4; 152 /** Look-up table table5. */ 153 private final byte[] table5; 154 155 /** 156 * @param rng Generator of uniformly distributed random numbers. 157 * @param distributionName Distribution name. 158 * @param prob The probabilities. 159 * @param offset The offset (must be positive). 160 */ 161 MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng, 162 String distributionName, 163 int[] prob, 164 int offset) { 165 super(rng, distributionName); 166 167 // Get table sizes for each base-64 digit 168 int n1 = 0; 169 int n2 = 0; 170 int n3 = 0; 171 int n4 = 0; 172 int n5 = 0; 173 for (final int m : prob) { 174 n1 += getBase64Digit(m, 1); 175 n2 += getBase64Digit(m, 2); 176 n3 += getBase64Digit(m, 3); 177 n4 += getBase64Digit(m, 4); 178 n5 += getBase64Digit(m, 5); 179 } 180 181 table1 = new byte[n1]; 182 table2 = new byte[n2]; 183 table3 = new byte[n3]; 184 table4 = new byte[n4]; 185 table5 = new byte[n5]; 186 187 // Compute offsets 188 t1 = n1 << 24; 189 t2 = t1 + (n2 << 18); 190 t3 = t2 + (n3 << 12); 191 t4 = t3 + (n4 << 6); 192 n1 = n2 = n3 = n4 = n5 = 0; 193 194 // Fill tables 195 for (int i = 0; i < prob.length; i++) { 196 final int m = prob[i]; 197 // Primitive type conversion will extract lower 8 bits 198 final byte k = (byte) (i + offset); 199 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k); 200 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k); 201 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k); 202 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k); 203 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k); 204 } 205 } 206 207 /** 208 * @param rng Generator of uniformly distributed random numbers. 209 * @param source Source to copy. 210 */ 211 private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng, 212 MarsagliaTsangWangBase64Int8DiscreteSampler source) { 213 super(rng, source); 214 t1 = source.t1; 215 t2 = source.t2; 216 t3 = source.t3; 217 t4 = source.t4; 218 table1 = source.table1; 219 table2 = source.table2; 220 table3 = source.table3; 221 table4 = source.table4; 222 table5 = source.table5; 223 } 224 225 /** 226 * Fill the table with the value. 227 * 228 * @param table Table. 229 * @param from Lower bound index (inclusive) 230 * @param to Upper bound index (exclusive) 231 * @param value Value. 232 * @return the upper bound index 233 */ 234 private static int fill(byte[] table, int from, int to, byte value) { 235 for (int i = from; i < to; i++) { 236 table[i] = value; 237 } 238 return to; 239 } 240 241 @Override 242 public int sample() { 243 final int j = rng.nextInt() >>> 2; 244 if (j < t1) { 245 return table1[j >>> 24] & MASK; 246 } 247 if (j < t2) { 248 return table2[(j - t1) >>> 18] & MASK; 249 } 250 if (j < t3) { 251 return table3[(j - t2) >>> 12] & MASK; 252 } 253 if (j < t4) { 254 return table4[(j - t3) >>> 6] & MASK; 255 } 256 // Note the tables are filled on the assumption that the sum of the probabilities. 257 // is >=2^30. If this is not true then the final table table5 will be smaller by the 258 // difference. So the tables *must* be constructed correctly. 259 return table5[j - t4] & MASK; 260 } 261 262 @Override 263 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 264 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this); 265 } 266 } 267 268 /** 269 * An implementation for the sample algorithm based on the decomposition of the 270 * index in the range {@code [0,2^30)} into 5 base-64 digits with 16-bit backing storage. 271 */ 272 private static final class MarsagliaTsangWangBase64Int16DiscreteSampler 273 extends AbstractMarsagliaTsangWangDiscreteSampler { 274 /** The mask to convert a {@code byte} to an unsigned 16-bit integer. */ 275 private static final int MASK = 0xffff; 276 277 /** Limit for look-up table 1. */ 278 private final int t1; 279 /** Limit for look-up table 2. */ 280 private final int t2; 281 /** Limit for look-up table 3. */ 282 private final int t3; 283 /** Limit for look-up table 4. */ 284 private final int t4; 285 286 /** Look-up table table1. */ 287 private final short[] table1; 288 /** Look-up table table2. */ 289 private final short[] table2; 290 /** Look-up table table3. */ 291 private final short[] table3; 292 /** Look-up table table4. */ 293 private final short[] table4; 294 /** Look-up table table5. */ 295 private final short[] table5; 296 297 /** 298 * @param rng Generator of uniformly distributed random numbers. 299 * @param distributionName Distribution name. 300 * @param prob The probabilities. 301 * @param offset The offset (must be positive). 302 */ 303 MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng, 304 String distributionName, 305 int[] prob, 306 int offset) { 307 super(rng, distributionName); 308 309 // Get table sizes for each base-64 digit 310 int n1 = 0; 311 int n2 = 0; 312 int n3 = 0; 313 int n4 = 0; 314 int n5 = 0; 315 for (final int m : prob) { 316 n1 += getBase64Digit(m, 1); 317 n2 += getBase64Digit(m, 2); 318 n3 += getBase64Digit(m, 3); 319 n4 += getBase64Digit(m, 4); 320 n5 += getBase64Digit(m, 5); 321 } 322 323 table1 = new short[n1]; 324 table2 = new short[n2]; 325 table3 = new short[n3]; 326 table4 = new short[n4]; 327 table5 = new short[n5]; 328 329 // Compute offsets 330 t1 = n1 << 24; 331 t2 = t1 + (n2 << 18); 332 t3 = t2 + (n3 << 12); 333 t4 = t3 + (n4 << 6); 334 n1 = n2 = n3 = n4 = n5 = 0; 335 336 // Fill tables 337 for (int i = 0; i < prob.length; i++) { 338 final int m = prob[i]; 339 // Primitive type conversion will extract lower 16 bits 340 final short k = (short) (i + offset); 341 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k); 342 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k); 343 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k); 344 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k); 345 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k); 346 } 347 } 348 349 /** 350 * @param rng Generator of uniformly distributed random numbers. 351 * @param source Source to copy. 352 */ 353 private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng, 354 MarsagliaTsangWangBase64Int16DiscreteSampler source) { 355 super(rng, source); 356 t1 = source.t1; 357 t2 = source.t2; 358 t3 = source.t3; 359 t4 = source.t4; 360 table1 = source.table1; 361 table2 = source.table2; 362 table3 = source.table3; 363 table4 = source.table4; 364 table5 = source.table5; 365 } 366 367 /** 368 * Fill the table with the value. 369 * 370 * @param table Table. 371 * @param from Lower bound index (inclusive) 372 * @param to Upper bound index (exclusive) 373 * @param value Value. 374 * @return the upper bound index 375 */ 376 private static int fill(short[] table, int from, int to, short value) { 377 for (int i = from; i < to; i++) { 378 table[i] = value; 379 } 380 return to; 381 } 382 383 @Override 384 public int sample() { 385 final int j = rng.nextInt() >>> 2; 386 if (j < t1) { 387 return table1[j >>> 24] & MASK; 388 } 389 if (j < t2) { 390 return table2[(j - t1) >>> 18] & MASK; 391 } 392 if (j < t3) { 393 return table3[(j - t2) >>> 12] & MASK; 394 } 395 if (j < t4) { 396 return table4[(j - t3) >>> 6] & MASK; 397 } 398 // Note the tables are filled on the assumption that the sum of the probabilities. 399 // is >=2^30. If this is not true then the final table table5 will be smaller by the 400 // difference. So the tables *must* be constructed correctly. 401 return table5[j - t4] & MASK; 402 } 403 404 @Override 405 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 406 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this); 407 } 408 } 409 410 /** 411 * An implementation for the sample algorithm based on the decomposition of the 412 * index in the range {@code [0,2^30)} into 5 base-64 digits with 32-bit backing storage. 413 */ 414 private static final class MarsagliaTsangWangBase64Int32DiscreteSampler 415 extends AbstractMarsagliaTsangWangDiscreteSampler { 416 /** Limit for look-up table 1. */ 417 private final int t1; 418 /** Limit for look-up table 2. */ 419 private final int t2; 420 /** Limit for look-up table 3. */ 421 private final int t3; 422 /** Limit for look-up table 4. */ 423 private final int t4; 424 425 /** Look-up table table1. */ 426 private final int[] table1; 427 /** Look-up table table2. */ 428 private final int[] table2; 429 /** Look-up table table3. */ 430 private final int[] table3; 431 /** Look-up table table4. */ 432 private final int[] table4; 433 /** Look-up table table5. */ 434 private final int[] table5; 435 436 /** 437 * @param rng Generator of uniformly distributed random numbers. 438 * @param distributionName Distribution name. 439 * @param prob The probabilities. 440 * @param offset The offset (must be positive). 441 */ 442 MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng, 443 String distributionName, 444 int[] prob, 445 int offset) { 446 super(rng, distributionName); 447 448 // Get table sizes for each base-64 digit 449 int n1 = 0; 450 int n2 = 0; 451 int n3 = 0; 452 int n4 = 0; 453 int n5 = 0; 454 for (final int m : prob) { 455 n1 += getBase64Digit(m, 1); 456 n2 += getBase64Digit(m, 2); 457 n3 += getBase64Digit(m, 3); 458 n4 += getBase64Digit(m, 4); 459 n5 += getBase64Digit(m, 5); 460 } 461 462 table1 = new int[n1]; 463 table2 = new int[n2]; 464 table3 = new int[n3]; 465 table4 = new int[n4]; 466 table5 = new int[n5]; 467 468 // Compute offsets 469 t1 = n1 << 24; 470 t2 = t1 + (n2 << 18); 471 t3 = t2 + (n3 << 12); 472 t4 = t3 + (n4 << 6); 473 n1 = n2 = n3 = n4 = n5 = 0; 474 475 // Fill tables 476 for (int i = 0; i < prob.length; i++) { 477 final int m = prob[i]; 478 final int k = i + offset; 479 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k); 480 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k); 481 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k); 482 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k); 483 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k); 484 } 485 } 486 487 /** 488 * @param rng Generator of uniformly distributed random numbers. 489 * @param source Source to copy. 490 */ 491 private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng, 492 MarsagliaTsangWangBase64Int32DiscreteSampler source) { 493 super(rng, source); 494 t1 = source.t1; 495 t2 = source.t2; 496 t3 = source.t3; 497 t4 = source.t4; 498 table1 = source.table1; 499 table2 = source.table2; 500 table3 = source.table3; 501 table4 = source.table4; 502 table5 = source.table5; 503 } 504 505 /** 506 * Fill the table with the value. 507 * 508 * @param table Table. 509 * @param from Lower bound index (inclusive) 510 * @param to Upper bound index (exclusive) 511 * @param value Value. 512 * @return the upper bound index 513 */ 514 private static int fill(int[] table, int from, int to, int value) { 515 for (int i = from; i < to; i++) { 516 table[i] = value; 517 } 518 return to; 519 } 520 521 @Override 522 public int sample() { 523 final int j = rng.nextInt() >>> 2; 524 if (j < t1) { 525 return table1[j >>> 24]; 526 } 527 if (j < t2) { 528 return table2[(j - t1) >>> 18]; 529 } 530 if (j < t3) { 531 return table3[(j - t2) >>> 12]; 532 } 533 if (j < t4) { 534 return table4[(j - t3) >>> 6]; 535 } 536 // Note the tables are filled on the assumption that the sum of the probabilities. 537 // is >=2^30. If this is not true then the final table table5 will be smaller by the 538 // difference. So the tables *must* be constructed correctly. 539 return table5[j - t4]; 540 } 541 542 @Override 543 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 544 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this); 545 } 546 } 547 548 549 550 /** Class contains only static methods. */ 551 private MarsagliaTsangWangDiscreteSampler() {} 552 553 /** 554 * Gets the k<sup>th</sup> base 64 digit of {@code m}. 555 * 556 * @param m the value m. 557 * @param k the digit. 558 * @return the base 64 digit 559 */ 560 private static int getBase64Digit(int m, int k) { 561 return (m >>> (30 - 6 * k)) & 63; 562 } 563 564 /** 565 * Convert the probability to an integer in the range [0,2^30]. This is the numerator of 566 * a fraction with assumed denominator 2<sup>30</sup>. 567 * 568 * @param p Probability. 569 * @return the fraction numerator 570 */ 571 private static int toUnsignedInt30(double p) { 572 return (int) (p * INT_30 + 0.5); 573 } 574 575 /** 576 * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is 577 * {@code i + offset}. 578 * 579 * <p>The sum of the probabilities must be {@code >=} 2<sup>30</sup>. Only the 580 * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p> 581 * 582 * @param rng Generator of uniformly distributed random numbers. 583 * @param distributionName Distribution name. 584 * @param prob The probabilities. 585 * @param offset The offset (must be positive). 586 * @return Sampler. 587 */ 588 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng, 589 String distributionName, 590 int[] prob, 591 int offset) { 592 // Note: No argument checks for private method. 593 594 // Choose implementation based on the maximum index 595 final int maxIndex = prob.length + offset - 1; 596 if (maxIndex < INT_8) { 597 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset); 598 } 599 if (maxIndex < INT_16) { 600 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset); 601 } 602 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset); 603 } 604 605 // ========================================================================= 606 // The following public classes provide factory methods to construct a sampler for: 607 // - Enumerated probability distribution (from provided double[] probabilities) 608 // - Poisson distribution for mean <= 1024 609 // - Binomial distribution for trials <= 65535 610 // ========================================================================= 611 612 /** 613 * Create a sampler for an enumerated distribution of {@code n} values each with an 614 * associated probability. 615 * The samples corresponding to each probability are assumed to be a natural sequence 616 * starting at zero. 617 */ 618 public static final class Enumerated { 619 /** The name of the enumerated probability distribution. */ 620 private static final String ENUMERATED_NAME = "Enumerated"; 621 622 /** Class contains only static methods. */ 623 private Enumerated() {} 624 625 /** 626 * Creates a sampler for an enumerated distribution of {@code n} values each with an 627 * associated probability. 628 * 629 * <p>The probabilities will be normalised using their sum. The only requirement 630 * is the sum is positive.</p> 631 * 632 * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Note that 633 * probabilities are adjusted to the nearest 2<sup>-30</sup> due to round-off during 634 * the normalisation conversion. Consequently any probability less than 2<sup>-31</sup> 635 * will not be observed in samples.</p> 636 * 637 * @param rng Generator of uniformly distributed random numbers. 638 * @param probabilities The list of probabilities. 639 * @return Sampler. 640 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 641 * probability is negative, infinite or {@code NaN}, or the sum of all 642 * probabilities is not strictly positive. 643 */ 644 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 645 double[] probabilities) { 646 return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0); 647 } 648 649 /** 650 * Normalise the probabilities to integers that sum to 2<sup>30</sup>. 651 * 652 * @param probabilities The list of probabilities. 653 * @return the normalised probabilities. 654 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 655 * probability is negative, infinite or {@code NaN}, or the sum of all 656 * probabilities is not strictly positive. 657 */ 658 private static int[] normaliseProbabilities(double[] probabilities) { 659 final double sumProb = InternalUtils.validateProbabilities(probabilities); 660 661 // Compute the normalisation: 2^30 / sum 662 final double normalisation = INT_30 / sumProb; 663 final int[] prob = new int[probabilities.length]; 664 int sum = 0; 665 int max = 0; 666 int mode = 0; 667 for (int i = 0; i < prob.length; i++) { 668 // Add 0.5 for rounding 669 final int p = (int) (probabilities[i] * normalisation + 0.5); 670 sum += p; 671 // Find the mode (maximum probability) 672 if (max < p) { 673 max = p; 674 mode = i; 675 } 676 prob[i] = p; 677 } 678 679 // The sum must be >= 2^30. 680 // Here just compensate the difference onto the highest probability. 681 prob[mode] += INT_30 - sum; 682 683 return prob; 684 } 685 } 686 687 /** 688 * Create a sampler for the Poisson distribution. 689 */ 690 public static final class Poisson { 691 /** The name of the Poisson distribution. */ 692 private static final String POISSON_NAME = "Poisson"; 693 694 /** 695 * Upper bound on the mean for the Poisson distribution. 696 * 697 * <p>The original source code provided in Marsaglia, et al (2004) has no explicit 698 * limit but the code fails at mean {@code >= 1941} as the transform to compute p(x=mode) 699 * produces infinity. Use a conservative limit of 1024.</p> 700 */ 701 702 private static final double MAX_MEAN = 1024; 703 /** 704 * The threshold for the mean of the Poisson distribution to switch the method used 705 * to compute the probabilities. This is taken from the example software provided by 706 * Marsaglia, et al (2004). 707 */ 708 private static final double MEAN_THRESHOLD = 21.4; 709 710 /** Class contains only static methods. */ 711 private Poisson() {} 712 713 /** 714 * Creates a sampler for the Poisson distribution. 715 * 716 * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p> 717 * 718 * <p>Storage requirements depend on the tabulated probability values. Example storage 719 * requirements are listed below.</p> 720 * 721 * <pre> 722 * mean table size kB 723 * 0.25 882 0.88 724 * 0.5 1135 1.14 725 * 1 1200 1.20 726 * 2 1451 1.45 727 * 4 1955 1.96 728 * 8 2961 2.96 729 * 16 4410 4.41 730 * 32 6115 6.11 731 * 64 8499 8.50 732 * 128 11528 11.53 733 * 256 15935 31.87 734 * 512 20912 41.82 735 * 1024 30614 61.23 736 * </pre> 737 * 738 * <p>Note: Storage changes to 2 bytes per index between {@code mean=128} and {@code mean=256}.</p> 739 * 740 * @param rng Generator of uniformly distributed random numbers. 741 * @param mean Mean. 742 * @return Sampler. 743 * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}. 744 */ 745 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 746 double mean) { 747 validatePoissonDistributionParameters(mean); 748 749 // Create the distribution either from X=0 or from X=mode when the mean is high. 750 return mean < MEAN_THRESHOLD ? 751 createPoissonDistributionFromX0(rng, mean) : 752 createPoissonDistributionFromXMode(rng, mean); 753 } 754 755 /** 756 * Validate the Poisson distribution parameters. 757 * 758 * @param mean Mean. 759 * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}. 760 */ 761 private static void validatePoissonDistributionParameters(double mean) { 762 InternalUtils.requireStrictlyPositive(mean, "mean"); 763 if (mean > MAX_MEAN) { 764 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); 765 } 766 } 767 768 /** 769 * Creates the Poisson distribution by computing probabilities recursively from {@code X=0}. 770 * 771 * @param rng Generator of uniformly distributed random numbers. 772 * @param mean Mean. 773 * @return Sampler. 774 */ 775 private static SharedStateDiscreteSampler createPoissonDistributionFromX0( 776 UniformRandomProvider rng, double mean) { 777 final double p0 = Math.exp(-mean); 778 779 // Recursive update of Poisson probability until the value is too small 780 // p(x + 1) = p(x) * mean / (x + 1) 781 double p = p0; 782 int i = 1; 783 while (p * DOUBLE_31 >= 1) { 784 p *= mean / i++; 785 } 786 787 // Probabilities are 30-bit integers, assumed denominator 2^30 788 final int size = i - 1; 789 final int[] prob = new int[size]; 790 791 p = p0; 792 prob[0] = toUnsignedInt30(p); 793 // The sum must exceed 2^30. In edges cases this is false due to round-off. 794 int sum = prob[0]; 795 for (i = 1; i < prob.length; i++) { 796 p *= mean / i; 797 prob[i] = toUnsignedInt30(p); 798 sum += prob[i]; 799 } 800 801 // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)). 802 prob[(int) mean] += Math.max(0, INT_30 - sum); 803 804 // Note: offset = 0 805 return createSampler(rng, POISSON_NAME, prob, 0); 806 } 807 808 /** 809 * Creates the Poisson distribution by computing probabilities recursively upward and downward 810 * from {@code X=mode}, the location of the largest p-value. 811 * 812 * @param rng Generator of uniformly distributed random numbers. 813 * @param mean Mean. 814 * @return Sampler. 815 */ 816 private static SharedStateDiscreteSampler createPoissonDistributionFromXMode( 817 UniformRandomProvider rng, double mean) { 818 // If mean >= 21.4, generate from largest p-value up, then largest down. 819 // The largest p-value will be at the mode (floor(mean)). 820 821 // Find p(x=mode) 822 final int mode = (int) mean; 823 // This transform is stable until mean >= 1941 where p will result in Infinity 824 // before the divisor i is large enough to start reducing the product (i.e. i > c). 825 final double c = mean * Math.exp(-mean / mode); 826 double p = 1.0; 827 for (int i = 1; i <= mode; i++) { 828 p *= c / i; 829 } 830 final double pMode = p; 831 832 // Find the upper limit using recursive computation of the p-value. 833 // Note this will exit when i overflows to negative so no check on the range 834 int i = mode + 1; 835 while (p * DOUBLE_31 >= 1) { 836 p *= mean / i++; 837 } 838 final int last = i - 2; 839 840 // Find the lower limit using recursive computation of the p-value. 841 p = pMode; 842 int j = -1; 843 for (i = mode - 1; i >= 0; i--) { 844 p *= (i + 1) / mean; 845 if (p * DOUBLE_31 < 1) { 846 j = i; 847 break; 848 } 849 } 850 851 // Probabilities are 30-bit integers, assumed denominator 2^30. 852 // This is the minimum sample value: prob[x - offset] = p(x) 853 final int offset = j + 1; 854 final int size = last - offset + 1; 855 final int[] prob = new int[size]; 856 857 p = pMode; 858 prob[mode - offset] = toUnsignedInt30(p); 859 // The sum must exceed 2^30. In edges cases this is false due to round-off. 860 int sum = prob[mode - offset]; 861 // From mode to upper limit 862 for (i = mode + 1; i <= last; i++) { 863 p *= mean / i; 864 prob[i - offset] = toUnsignedInt30(p); 865 sum += prob[i - offset]; 866 } 867 // From mode to lower limit 868 p = pMode; 869 for (i = mode - 1; i >= offset; i--) { 870 p *= (i + 1) / mean; 871 prob[i - offset] = toUnsignedInt30(p); 872 sum += prob[i - offset]; 873 } 874 875 // If the sum is < 2^30 add the remaining sum to the mode. 876 // If above 2^30 then the effect is truncation of the long tail of the distribution. 877 prob[mode - offset] += Math.max(0, INT_30 - sum); 878 879 return createSampler(rng, POISSON_NAME, prob, offset); 880 } 881 } 882 883 /** 884 * Create a sampler for the Binomial distribution. 885 */ 886 public static final class Binomial { 887 /** The name of the Binomial distribution. */ 888 private static final String BINOMIAL_NAME = "Binomial"; 889 890 /** 891 * Return a fixed result for the Binomial distribution. This is a special class to handle 892 * an edge case of probability of success equal to 0 or 1. 893 */ 894 private static final class MarsagliaTsangWangFixedResultBinomialSampler 895 extends AbstractMarsagliaTsangWangDiscreteSampler { 896 /** The result. */ 897 private final int result; 898 899 /** 900 * @param result Result. 901 */ 902 MarsagliaTsangWangFixedResultBinomialSampler(int result) { 903 super(null, BINOMIAL_NAME); 904 this.result = result; 905 } 906 907 @Override 908 public int sample() { 909 return result; 910 } 911 912 @Override 913 public String toString() { 914 return BINOMIAL_NAME + " deviate"; 915 } 916 917 @Override 918 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 919 // No shared state 920 return this; 921 } 922 } 923 924 /** 925 * Return an inversion result for the Binomial distribution. This assumes the 926 * following: 927 * 928 * <pre> 929 * Binomial(n, p) = 1 - Binomial(n, 1 - p) 930 * </pre> 931 */ 932 private static final class MarsagliaTsangWangInversionBinomialSampler 933 extends AbstractMarsagliaTsangWangDiscreteSampler { 934 /** The number of trials. */ 935 private final int trials; 936 /** The Binomial distribution sampler. */ 937 private final SharedStateDiscreteSampler sampler; 938 939 /** 940 * @param trials Number of trials. 941 * @param sampler Binomial distribution sampler. 942 */ 943 MarsagliaTsangWangInversionBinomialSampler(int trials, 944 SharedStateDiscreteSampler sampler) { 945 super(null, BINOMIAL_NAME); 946 this.trials = trials; 947 this.sampler = sampler; 948 } 949 950 @Override 951 public int sample() { 952 return trials - sampler.sample(); 953 } 954 955 @Override 956 public String toString() { 957 return sampler.toString(); 958 } 959 960 @Override 961 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 962 return new MarsagliaTsangWangInversionBinomialSampler(this.trials, 963 this.sampler.withUniformRandomProvider(rng)); 964 } 965 } 966 967 /** Class contains only static methods. */ 968 private Binomial() {} 969 970 /** 971 * Creates a sampler for the Binomial distribution. 972 * 973 * <p>Any probability less than 2<sup>-31</sup> will not be observed in samples.</p> 974 * 975 * <p>Storage requirements depend on the tabulated probability values. Example storage 976 * requirements are listed below (in kB).</p> 977 * 978 * <pre> 979 * p 980 * trials 0.5 0.1 0.01 0.001 981 * 4 0.06 0.63 0.44 0.44 982 * 16 0.69 1.14 0.76 0.44 983 * 64 4.73 2.40 1.14 0.51 984 * 256 8.63 5.17 1.89 0.82 985 * 1024 31.12 9.45 3.34 0.89 986 * </pre> 987 * 988 * <p>The method requires that the Binomial distribution probability at {@code x=0} can be computed. 989 * This will fail when {@code (1 - p)^trials == 0} which requires {@code trials} to be large 990 * and/or {@code p} to be small. In this case an exception is raised.</p> 991 * 992 * @param rng Generator of uniformly distributed random numbers. 993 * @param trials Number of trials. 994 * @param probabilityOfSuccess Probability of success (p). 995 * @return Sampler. 996 * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16}, 997 * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot 998 * be computed. 999 */ 1000 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 1001 int trials, 1002 double probabilityOfSuccess) { 1003 validateBinomialDistributionParameters(trials, probabilityOfSuccess); 1004 1005 // Handle edge cases 1006 if (probabilityOfSuccess == 0) { 1007 return new MarsagliaTsangWangFixedResultBinomialSampler(0); 1008 } 1009 if (probabilityOfSuccess == 1) { 1010 return new MarsagliaTsangWangFixedResultBinomialSampler(trials); 1011 } 1012 1013 // Check the supported size. 1014 if (trials >= INT_16) { 1015 throw new IllegalArgumentException("Unsupported number of trials: " + trials); 1016 } 1017 1018 return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess); 1019 } 1020 1021 /** 1022 * Validate the Binomial distribution parameters. 1023 * 1024 * @param trials Number of trials. 1025 * @param probabilityOfSuccess Probability of success (p). 1026 * @throws IllegalArgumentException if {@code trials < 0} or 1027 * {@code p} is not in the range {@code [0-1]} 1028 */ 1029 private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) { 1030 if (trials < 0) { 1031 throw new IllegalArgumentException("Trials is not positive: " + trials); 1032 } 1033 InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success"); 1034 } 1035 1036 /** 1037 * Creates the Binomial distribution sampler. 1038 * 1039 * <p>This assumes the parameters for the distribution are valid. The method 1040 * will only fail if the initial probability for {@code X=0} is zero.</p> 1041 * 1042 * @param rng Generator of uniformly distributed random numbers. 1043 * @param trials Number of trials. 1044 * @param probabilityOfSuccess Probability of success (p). 1045 * @return Sampler. 1046 * @throws IllegalArgumentException if the probability distribution cannot be 1047 * computed. 1048 */ 1049 private static SharedStateDiscreteSampler createBinomialDistributionSampler( 1050 UniformRandomProvider rng, int trials, double probabilityOfSuccess) { 1051 1052 // The maximum supported value for Math.exp is approximately -744. 1053 // This occurs when trials is large and p is close to 1. 1054 // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j 1055 final boolean useInversion = probabilityOfSuccess > 0.5; 1056 final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess; 1057 1058 // Check if the distribution can be computed 1059 final double p0 = Math.exp(trials * Math.log(1 - p)); 1060 if (p0 < Double.MIN_VALUE) { 1061 throw new IllegalArgumentException("Unable to compute distribution"); 1062 } 1063 1064 // First find size of probability array 1065 double t = p0; 1066 final double h = p / (1 - p); 1067 // Find first probability above the threshold of 2^-31 1068 int begin = 0; 1069 if (t * DOUBLE_31 < 1) { 1070 // Somewhere after p(0) 1071 // Note: 1072 // If this loop is entered p(0) is < 2^-31. 1073 // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either 1074 // p=0.5 or trials=2^16-1 and does not fail to find the beginning. 1075 for (int i = 1; i <= trials; i++) { 1076 t *= (trials + 1 - i) * h / i; 1077 if (t * DOUBLE_31 >= 1) { 1078 begin = i; 1079 break; 1080 } 1081 } 1082 } 1083 // Find last probability 1084 int end = trials; 1085 for (int i = begin + 1; i <= trials; i++) { 1086 t *= (trials + 1 - i) * h / i; 1087 if (t * DOUBLE_31 < 1) { 1088 end = i - 1; 1089 break; 1090 } 1091 } 1092 1093 return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion, 1094 p0, begin, end); 1095 } 1096 1097 /** 1098 * Creates the Binomial distribution sampler using only the probability values for {@code X} 1099 * between the begin and the end (inclusive). 1100 * 1101 * @param rng Generator of uniformly distributed random numbers. 1102 * @param trials Number of trials. 1103 * @param p Probability of success (p). 1104 * @param useInversion Set to {@code true} if the probability was inverted. 1105 * @param p0 Probability at {@code X=0} 1106 * @param begin Begin value {@code X} for the distribution. 1107 * @param end End value {@code X} for the distribution. 1108 * @return Sampler. 1109 */ 1110 private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange( 1111 UniformRandomProvider rng, int trials, double p, 1112 boolean useInversion, double p0, int begin, int end) { 1113 1114 // Assign probability values as 30-bit integers 1115 final int size = end - begin + 1; 1116 final int[] prob = new int[size]; 1117 double t = p0; 1118 final double h = p / (1 - p); 1119 for (int i = 1; i <= begin; i++) { 1120 t *= (trials + 1 - i) * h / i; 1121 } 1122 int sum = toUnsignedInt30(t); 1123 prob[0] = sum; 1124 for (int i = begin + 1; i <= end; i++) { 1125 t *= (trials + 1 - i) * h / i; 1126 prob[i - begin] = toUnsignedInt30(t); 1127 sum += prob[i - begin]; 1128 } 1129 1130 // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))). 1131 // If above 2^30 then the effect is truncation of the long tail of the distribution. 1132 final int mode = (int) ((trials + 1) * p) - begin; 1133 prob[mode] += Math.max(0, INT_30 - sum); 1134 1135 final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin); 1136 1137 // Check if an inversion was made 1138 return useInversion ? 1139 new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) : 1140 sampler; 1141 } 1142 } 1143}