1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public final class MarsagliaTsangWangDiscreteSampler {
53
54 private static final int INT_8 = 1 << 8;
55
56 private static final int INT_16 = 1 << 16;
57
58 private static final int INT_30 = 1 << 30;
59
60 private static final double DOUBLE_31 = 1L << 31;
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91 private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
92 implements SharedStateDiscreteSampler {
93
94 protected final UniformRandomProvider rng;
95
96
97 private final String distributionName;
98
99
100
101
102
103 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104 String distributionName) {
105 this.rng = rng;
106 this.distributionName = distributionName;
107 }
108
109
110
111
112
113 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114 AbstractMarsagliaTsangWangDiscreteSampler source) {
115 this.rng = rng;
116 this.distributionName = source.distributionName;
117 }
118
119
120 @Override
121 public String toString() {
122 return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123 }
124 }
125
126
127
128
129
130 private static final class MarsagliaTsangWangBase64Int8DiscreteSampler
131 extends AbstractMarsagliaTsangWangDiscreteSampler {
132
133 private static final int MASK = 0xff;
134
135
136 private final int t1;
137
138 private final int t2;
139
140 private final int t3;
141
142 private final int t4;
143
144
145 private final byte[] table1;
146
147 private final byte[] table2;
148
149 private final byte[] table3;
150
151 private final byte[] table4;
152
153 private final byte[] table5;
154
155
156
157
158
159
160
161 MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162 String distributionName,
163 int[] prob,
164 int offset) {
165 super(rng, distributionName);
166
167
168 int n1 = 0;
169 int n2 = 0;
170 int n3 = 0;
171 int n4 = 0;
172 int n5 = 0;
173 for (final int m : prob) {
174 n1 += getBase64Digit(m, 1);
175 n2 += getBase64Digit(m, 2);
176 n3 += getBase64Digit(m, 3);
177 n4 += getBase64Digit(m, 4);
178 n5 += getBase64Digit(m, 5);
179 }
180
181 table1 = new byte[n1];
182 table2 = new byte[n2];
183 table3 = new byte[n3];
184 table4 = new byte[n4];
185 table5 = new byte[n5];
186
187
188 t1 = n1 << 24;
189 t2 = t1 + (n2 << 18);
190 t3 = t2 + (n3 << 12);
191 t4 = t3 + (n4 << 6);
192 n1 = n2 = n3 = n4 = n5 = 0;
193
194
195 for (int i = 0; i < prob.length; i++) {
196 final int m = prob[i];
197
198 final byte k = (byte) (i + offset);
199 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
200 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
201 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
202 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
203 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
204 }
205 }
206
207
208
209
210
211 private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
212 MarsagliaTsangWangBase64Int8DiscreteSampler source) {
213 super(rng, source);
214 t1 = source.t1;
215 t2 = source.t2;
216 t3 = source.t3;
217 t4 = source.t4;
218 table1 = source.table1;
219 table2 = source.table2;
220 table3 = source.table3;
221 table4 = source.table4;
222 table5 = source.table5;
223 }
224
225
226
227
228
229
230
231
232
233
234 private static int fill(byte[] table, int from, int to, byte value) {
235 for (int i = from; i < to; i++) {
236 table[i] = value;
237 }
238 return to;
239 }
240
241 @Override
242 public int sample() {
243 final int j = rng.nextInt() >>> 2;
244 if (j < t1) {
245 return table1[j >>> 24] & MASK;
246 }
247 if (j < t2) {
248 return table2[(j - t1) >>> 18] & MASK;
249 }
250 if (j < t3) {
251 return table3[(j - t2) >>> 12] & MASK;
252 }
253 if (j < t4) {
254 return table4[(j - t3) >>> 6] & MASK;
255 }
256
257
258
259 return table5[j - t4] & MASK;
260 }
261
262 @Override
263 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
264 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
265 }
266 }
267
268
269
270
271
272 private static final class MarsagliaTsangWangBase64Int16DiscreteSampler
273 extends AbstractMarsagliaTsangWangDiscreteSampler {
274
275 private static final int MASK = 0xffff;
276
277
278 private final int t1;
279
280 private final int t2;
281
282 private final int t3;
283
284 private final int t4;
285
286
287 private final short[] table1;
288
289 private final short[] table2;
290
291 private final short[] table3;
292
293 private final short[] table4;
294
295 private final short[] table5;
296
297
298
299
300
301
302
303 MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
304 String distributionName,
305 int[] prob,
306 int offset) {
307 super(rng, distributionName);
308
309
310 int n1 = 0;
311 int n2 = 0;
312 int n3 = 0;
313 int n4 = 0;
314 int n5 = 0;
315 for (final int m : prob) {
316 n1 += getBase64Digit(m, 1);
317 n2 += getBase64Digit(m, 2);
318 n3 += getBase64Digit(m, 3);
319 n4 += getBase64Digit(m, 4);
320 n5 += getBase64Digit(m, 5);
321 }
322
323 table1 = new short[n1];
324 table2 = new short[n2];
325 table3 = new short[n3];
326 table4 = new short[n4];
327 table5 = new short[n5];
328
329
330 t1 = n1 << 24;
331 t2 = t1 + (n2 << 18);
332 t3 = t2 + (n3 << 12);
333 t4 = t3 + (n4 << 6);
334 n1 = n2 = n3 = n4 = n5 = 0;
335
336
337 for (int i = 0; i < prob.length; i++) {
338 final int m = prob[i];
339
340 final short k = (short) (i + offset);
341 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
342 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
343 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
344 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
345 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
346 }
347 }
348
349
350
351
352
353 private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
354 MarsagliaTsangWangBase64Int16DiscreteSampler source) {
355 super(rng, source);
356 t1 = source.t1;
357 t2 = source.t2;
358 t3 = source.t3;
359 t4 = source.t4;
360 table1 = source.table1;
361 table2 = source.table2;
362 table3 = source.table3;
363 table4 = source.table4;
364 table5 = source.table5;
365 }
366
367
368
369
370
371
372
373
374
375
376 private static int fill(short[] table, int from, int to, short value) {
377 for (int i = from; i < to; i++) {
378 table[i] = value;
379 }
380 return to;
381 }
382
383 @Override
384 public int sample() {
385 final int j = rng.nextInt() >>> 2;
386 if (j < t1) {
387 return table1[j >>> 24] & MASK;
388 }
389 if (j < t2) {
390 return table2[(j - t1) >>> 18] & MASK;
391 }
392 if (j < t3) {
393 return table3[(j - t2) >>> 12] & MASK;
394 }
395 if (j < t4) {
396 return table4[(j - t3) >>> 6] & MASK;
397 }
398
399
400
401 return table5[j - t4] & MASK;
402 }
403
404 @Override
405 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
406 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
407 }
408 }
409
410
411
412
413
414 private static final class MarsagliaTsangWangBase64Int32DiscreteSampler
415 extends AbstractMarsagliaTsangWangDiscreteSampler {
416
417 private final int t1;
418
419 private final int t2;
420
421 private final int t3;
422
423 private final int t4;
424
425
426 private final int[] table1;
427
428 private final int[] table2;
429
430 private final int[] table3;
431
432 private final int[] table4;
433
434 private final int[] table5;
435
436
437
438
439
440
441
442 MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
443 String distributionName,
444 int[] prob,
445 int offset) {
446 super(rng, distributionName);
447
448
449 int n1 = 0;
450 int n2 = 0;
451 int n3 = 0;
452 int n4 = 0;
453 int n5 = 0;
454 for (final int m : prob) {
455 n1 += getBase64Digit(m, 1);
456 n2 += getBase64Digit(m, 2);
457 n3 += getBase64Digit(m, 3);
458 n4 += getBase64Digit(m, 4);
459 n5 += getBase64Digit(m, 5);
460 }
461
462 table1 = new int[n1];
463 table2 = new int[n2];
464 table3 = new int[n3];
465 table4 = new int[n4];
466 table5 = new int[n5];
467
468
469 t1 = n1 << 24;
470 t2 = t1 + (n2 << 18);
471 t3 = t2 + (n3 << 12);
472 t4 = t3 + (n4 << 6);
473 n1 = n2 = n3 = n4 = n5 = 0;
474
475
476 for (int i = 0; i < prob.length; i++) {
477 final int m = prob[i];
478 final int k = i + offset;
479 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
480 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
481 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
482 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
483 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
484 }
485 }
486
487
488
489
490
491 private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
492 MarsagliaTsangWangBase64Int32DiscreteSampler source) {
493 super(rng, source);
494 t1 = source.t1;
495 t2 = source.t2;
496 t3 = source.t3;
497 t4 = source.t4;
498 table1 = source.table1;
499 table2 = source.table2;
500 table3 = source.table3;
501 table4 = source.table4;
502 table5 = source.table5;
503 }
504
505
506
507
508
509
510
511
512
513
514 private static int fill(int[] table, int from, int to, int value) {
515 for (int i = from; i < to; i++) {
516 table[i] = value;
517 }
518 return to;
519 }
520
521 @Override
522 public int sample() {
523 final int j = rng.nextInt() >>> 2;
524 if (j < t1) {
525 return table1[j >>> 24];
526 }
527 if (j < t2) {
528 return table2[(j - t1) >>> 18];
529 }
530 if (j < t3) {
531 return table3[(j - t2) >>> 12];
532 }
533 if (j < t4) {
534 return table4[(j - t3) >>> 6];
535 }
536
537
538
539 return table5[j - t4];
540 }
541
542 @Override
543 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
544 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
545 }
546 }
547
548
549
550
/** Not instantiable: samplers are obtained from the static factories of the nested classes. */
private MarsagliaTsangWangDiscreteSampler() {
    // No instances.
}
552
553
554
555
556
557
558
559
560 private static int getBase64Digit(int m, int k) {
561 return (m >>> (30 - 6 * k)) & 63;
562 }
563
564
565
566
567
568
569
570
571 private static int toUnsignedInt30(double p) {
572 return (int) (p * INT_30 + 0.5);
573 }
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
589 String distributionName,
590 int[] prob,
591 int offset) {
592
593
594
595 final int maxIndex = prob.length + offset - 1;
596 if (maxIndex < INT_8) {
597 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
598 }
599 if (maxIndex < INT_16) {
600 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
601 }
602 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
603 }
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618 public static final class Enumerated {
619
620 private static final String ENUMERATED_NAME = "Enumerated";
621
622
623 private Enumerated() {}
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
645 double[] probabilities) {
646 return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
647 }
648
649
650
651
652
653
654
655
656
657
658 private static int[] normaliseProbabilities(double[] probabilities) {
659 final double sumProb = InternalUtils.validateProbabilities(probabilities);
660
661
662 final double normalisation = INT_30 / sumProb;
663 final int[] prob = new int[probabilities.length];
664 int sum = 0;
665 int max = 0;
666 int mode = 0;
667 for (int i = 0; i < prob.length; i++) {
668
669 final int p = (int) (probabilities[i] * normalisation + 0.5);
670 sum += p;
671
672 if (max < p) {
673 max = p;
674 mode = i;
675 }
676 prob[i] = p;
677 }
678
679
680
681 prob[mode] += INT_30 - sum;
682
683 return prob;
684 }
685 }
686
687
688
689
690 public static final class Poisson {
691
692 private static final String POISSON_NAME = "Poisson";
693
694
695
696
697
698
699
700
701
702 private static final double MAX_MEAN = 1024;
703
704
705
706
707
708 private static final double MEAN_THRESHOLD = 21.4;
709
710
711 private Poisson() {}
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
746 double mean) {
747 validatePoissonDistributionParameters(mean);
748
749
750 return mean < MEAN_THRESHOLD ?
751 createPoissonDistributionFromX0(rng, mean) :
752 createPoissonDistributionFromXMode(rng, mean);
753 }
754
755
756
757
758
759
760
761 private static void validatePoissonDistributionParameters(double mean) {
762 InternalUtils.requireStrictlyPositive(mean, "mean");
763 if (mean > MAX_MEAN) {
764 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
765 }
766 }
767
768
769
770
771
772
773
774
775 private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
776 UniformRandomProvider rng, double mean) {
777 final double p0 = Math.exp(-mean);
778
779
780
781 double p = p0;
782 int i = 1;
783 while (p * DOUBLE_31 >= 1) {
784 p *= mean / i++;
785 }
786
787
788 final int size = i - 1;
789 final int[] prob = new int[size];
790
791 p = p0;
792 prob[0] = toUnsignedInt30(p);
793
794 int sum = prob[0];
795 for (i = 1; i < prob.length; i++) {
796 p *= mean / i;
797 prob[i] = toUnsignedInt30(p);
798 sum += prob[i];
799 }
800
801
802 prob[(int) mean] += Math.max(0, INT_30 - sum);
803
804
805 return createSampler(rng, POISSON_NAME, prob, 0);
806 }
807
808
809
810
811
812
813
814
815
816 private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
817 UniformRandomProvider rng, double mean) {
818
819
820
821
822 final int mode = (int) mean;
823
824
825 final double c = mean * Math.exp(-mean / mode);
826 double p = 1.0;
827 for (int i = 1; i <= mode; i++) {
828 p *= c / i;
829 }
830 final double pMode = p;
831
832
833
834 int i = mode + 1;
835 while (p * DOUBLE_31 >= 1) {
836 p *= mean / i++;
837 }
838 final int last = i - 2;
839
840
841 p = pMode;
842 int j = -1;
843 for (i = mode - 1; i >= 0; i--) {
844 p *= (i + 1) / mean;
845 if (p * DOUBLE_31 < 1) {
846 j = i;
847 break;
848 }
849 }
850
851
852
853 final int offset = j + 1;
854 final int size = last - offset + 1;
855 final int[] prob = new int[size];
856
857 p = pMode;
858 prob[mode - offset] = toUnsignedInt30(p);
859
860 int sum = prob[mode - offset];
861
862 for (i = mode + 1; i <= last; i++) {
863 p *= mean / i;
864 prob[i - offset] = toUnsignedInt30(p);
865 sum += prob[i - offset];
866 }
867
868 p = pMode;
869 for (i = mode - 1; i >= offset; i--) {
870 p *= (i + 1) / mean;
871 prob[i - offset] = toUnsignedInt30(p);
872 sum += prob[i - offset];
873 }
874
875
876
877 prob[mode - offset] += Math.max(0, INT_30 - sum);
878
879 return createSampler(rng, POISSON_NAME, prob, offset);
880 }
881 }
882
883
884
885
886 public static final class Binomial {
887
888 private static final String BINOMIAL_NAME = "Binomial";
889
890
891
892
893
894 private static final class MarsagliaTsangWangFixedResultBinomialSampler
895 extends AbstractMarsagliaTsangWangDiscreteSampler {
896
897 private final int result;
898
899
900
901
902 MarsagliaTsangWangFixedResultBinomialSampler(int result) {
903 super(null, BINOMIAL_NAME);
904 this.result = result;
905 }
906
907 @Override
908 public int sample() {
909 return result;
910 }
911
912 @Override
913 public String toString() {
914 return BINOMIAL_NAME + " deviate";
915 }
916
917 @Override
918 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
919
920 return this;
921 }
922 }
923
924
925
926
927
928
929
930
931
932 private static final class MarsagliaTsangWangInversionBinomialSampler
933 extends AbstractMarsagliaTsangWangDiscreteSampler {
934
935 private final int trials;
936
937 private final SharedStateDiscreteSampler sampler;
938
939
940
941
942
943 MarsagliaTsangWangInversionBinomialSampler(int trials,
944 SharedStateDiscreteSampler sampler) {
945 super(null, BINOMIAL_NAME);
946 this.trials = trials;
947 this.sampler = sampler;
948 }
949
950 @Override
951 public int sample() {
952 return trials - sampler.sample();
953 }
954
955 @Override
956 public String toString() {
957 return sampler.toString();
958 }
959
960 @Override
961 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
962 return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
963 this.sampler.withUniformRandomProvider(rng));
964 }
965 }
966
967
968 private Binomial() {}
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1001 int trials,
1002 double probabilityOfSuccess) {
1003 validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1004
1005
1006 if (probabilityOfSuccess == 0) {
1007 return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1008 }
1009 if (probabilityOfSuccess == 1) {
1010 return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1011 }
1012
1013
1014 if (trials >= INT_16) {
1015 throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1016 }
1017
1018 return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1019 }
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029 private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1030 if (trials < 0) {
1031 throw new IllegalArgumentException("Trials is not positive: " + trials);
1032 }
1033 InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success");
1034 }
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049 private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1050 UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1051
1052
1053
1054
1055 final boolean useInversion = probabilityOfSuccess > 0.5;
1056 final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1057
1058
1059 final double p0 = Math.exp(trials * Math.log(1 - p));
1060 if (p0 < Double.MIN_VALUE) {
1061 throw new IllegalArgumentException("Unable to compute distribution");
1062 }
1063
1064
1065 double t = p0;
1066 final double h = p / (1 - p);
1067
1068 int begin = 0;
1069 if (t * DOUBLE_31 < 1) {
1070
1071
1072
1073
1074
1075 for (int i = 1; i <= trials; i++) {
1076 t *= (trials + 1 - i) * h / i;
1077 if (t * DOUBLE_31 >= 1) {
1078 begin = i;
1079 break;
1080 }
1081 }
1082 }
1083
1084 int end = trials;
1085 for (int i = begin + 1; i <= trials; i++) {
1086 t *= (trials + 1 - i) * h / i;
1087 if (t * DOUBLE_31 < 1) {
1088 end = i - 1;
1089 break;
1090 }
1091 }
1092
1093 return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1094 p0, begin, end);
1095 }
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110 private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1111 UniformRandomProvider rng, int trials, double p,
1112 boolean useInversion, double p0, int begin, int end) {
1113
1114
1115 final int size = end - begin + 1;
1116 final int[] prob = new int[size];
1117 double t = p0;
1118 final double h = p / (1 - p);
1119 for (int i = 1; i <= begin; i++) {
1120 t *= (trials + 1 - i) * h / i;
1121 }
1122 int sum = toUnsignedInt30(t);
1123 prob[0] = sum;
1124 for (int i = begin + 1; i <= end; i++) {
1125 t *= (trials + 1 - i) * h / i;
1126 prob[i - begin] = toUnsignedInt30(t);
1127 sum += prob[i - begin];
1128 }
1129
1130
1131
1132 final int mode = (int) ((trials + 1) * p) - begin;
1133 prob[mode] += Math.max(0, INT_30 - sum);
1134
1135 final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1136
1137
1138 return useInversion ?
1139 new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1140 sampler;
1141 }
1142 }
1143 }