1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public final class MarsagliaTsangWangDiscreteSampler {
53
54 private static final int INT_8 = 1 << 8;
55
56 private static final int INT_16 = 1 << 16;
57
58 private static final int INT_30 = 1 << 30;
59
60 private static final double DOUBLE_31 = 1L << 31;
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91 private abstract static class AbstractMarsagliaTsangWangDiscreteSampler
92 implements SharedStateDiscreteSampler {
93
94 protected final UniformRandomProvider rng;
95
96
97 private final String distributionName;
98
99
100
101
102
103 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
104 String distributionName) {
105 this.rng = rng;
106 this.distributionName = distributionName;
107 }
108
109
110
111
112
113 AbstractMarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
114 AbstractMarsagliaTsangWangDiscreteSampler source) {
115 this.rng = rng;
116 this.distributionName = source.distributionName;
117 }
118
119
120 @Override
121 public String toString() {
122 return "Marsaglia Tsang Wang " + distributionName + " deviate [" + rng.toString() + "]";
123 }
124 }
125
126
127
128
129
130 private static class MarsagliaTsangWangBase64Int8DiscreteSampler
131 extends AbstractMarsagliaTsangWangDiscreteSampler {
132
133 private static final int MASK = 0xff;
134
135
136 private final int t1;
137
138 private final int t2;
139
140 private final int t3;
141
142 private final int t4;
143
144
145 private final byte[] table1;
146
147 private final byte[] table2;
148
149 private final byte[] table3;
150
151 private final byte[] table4;
152
153 private final byte[] table5;
154
155
156
157
158
159
160
161 MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
162 String distributionName,
163 int[] prob,
164 int offset) {
165 super(rng, distributionName);
166
167
168 int n1 = 0;
169 int n2 = 0;
170 int n3 = 0;
171 int n4 = 0;
172 int n5 = 0;
173 for (final int m : prob) {
174 n1 += getBase64Digit(m, 1);
175 n2 += getBase64Digit(m, 2);
176 n3 += getBase64Digit(m, 3);
177 n4 += getBase64Digit(m, 4);
178 n5 += getBase64Digit(m, 5);
179 }
180
181 table1 = new byte[n1];
182 table2 = new byte[n2];
183 table3 = new byte[n3];
184 table4 = new byte[n4];
185 table5 = new byte[n5];
186
187
188 t1 = n1 << 24;
189 t2 = t1 + (n2 << 18);
190 t3 = t2 + (n3 << 12);
191 t4 = t3 + (n4 << 6);
192 n1 = n2 = n3 = n4 = n5 = 0;
193
194
195 for (int i = 0; i < prob.length; i++) {
196 final int m = prob[i];
197
198 final byte k = (byte) (i + offset);
199 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
200 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
201 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
202 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
203 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
204 }
205 }
206
207
208
209
210
211 private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng,
212 MarsagliaTsangWangBase64Int8DiscreteSampler source) {
213 super(rng, source);
214 t1 = source.t1;
215 t2 = source.t2;
216 t3 = source.t3;
217 t4 = source.t4;
218 table1 = source.table1;
219 table2 = source.table2;
220 table3 = source.table3;
221 table4 = source.table4;
222 table5 = source.table5;
223 }
224
225
226
227
228
229
230
231
232
233
234 private static int fill(byte[] table, int from, int to, byte value) {
235 for (int i = from; i < to; i++) {
236 table[i] = value;
237 }
238 return to;
239 }
240
241 @Override
242 public int sample() {
243 final int j = rng.nextInt() >>> 2;
244 if (j < t1) {
245 return table1[j >>> 24] & MASK;
246 }
247 if (j < t2) {
248 return table2[(j - t1) >>> 18] & MASK;
249 }
250 if (j < t3) {
251 return table3[(j - t2) >>> 12] & MASK;
252 }
253 if (j < t4) {
254 return table4[(j - t3) >>> 6] & MASK;
255 }
256
257
258
259 return table5[j - t4] & MASK;
260 }
261
262 @Override
263 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
264 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this);
265 }
266 }
267
268
269
270
271
272 private static class MarsagliaTsangWangBase64Int16DiscreteSampler
273 extends AbstractMarsagliaTsangWangDiscreteSampler {
274
275 private static final int MASK = 0xffff;
276
277
278 private final int t1;
279
280 private final int t2;
281
282 private final int t3;
283
284 private final int t4;
285
286
287 private final short[] table1;
288
289 private final short[] table2;
290
291 private final short[] table3;
292
293 private final short[] table4;
294
295 private final short[] table5;
296
297
298
299
300
301
302
303 MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
304 String distributionName,
305 int[] prob,
306 int offset) {
307 super(rng, distributionName);
308
309
310 int n1 = 0;
311 int n2 = 0;
312 int n3 = 0;
313 int n4 = 0;
314 int n5 = 0;
315 for (final int m : prob) {
316 n1 += getBase64Digit(m, 1);
317 n2 += getBase64Digit(m, 2);
318 n3 += getBase64Digit(m, 3);
319 n4 += getBase64Digit(m, 4);
320 n5 += getBase64Digit(m, 5);
321 }
322
323 table1 = new short[n1];
324 table2 = new short[n2];
325 table3 = new short[n3];
326 table4 = new short[n4];
327 table5 = new short[n5];
328
329
330 t1 = n1 << 24;
331 t2 = t1 + (n2 << 18);
332 t3 = t2 + (n3 << 12);
333 t4 = t3 + (n4 << 6);
334 n1 = n2 = n3 = n4 = n5 = 0;
335
336
337 for (int i = 0; i < prob.length; i++) {
338 final int m = prob[i];
339
340 final short k = (short) (i + offset);
341 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
342 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
343 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
344 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
345 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
346 }
347 }
348
349
350
351
352
353 private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng,
354 MarsagliaTsangWangBase64Int16DiscreteSampler source) {
355 super(rng, source);
356 t1 = source.t1;
357 t2 = source.t2;
358 t3 = source.t3;
359 t4 = source.t4;
360 table1 = source.table1;
361 table2 = source.table2;
362 table3 = source.table3;
363 table4 = source.table4;
364 table5 = source.table5;
365 }
366
367
368
369
370
371
372
373
374
375
376 private static int fill(short[] table, int from, int to, short value) {
377 for (int i = from; i < to; i++) {
378 table[i] = value;
379 }
380 return to;
381 }
382
383 @Override
384 public int sample() {
385 final int j = rng.nextInt() >>> 2;
386 if (j < t1) {
387 return table1[j >>> 24] & MASK;
388 }
389 if (j < t2) {
390 return table2[(j - t1) >>> 18] & MASK;
391 }
392 if (j < t3) {
393 return table3[(j - t2) >>> 12] & MASK;
394 }
395 if (j < t4) {
396 return table4[(j - t3) >>> 6] & MASK;
397 }
398
399
400
401 return table5[j - t4] & MASK;
402 }
403
404 @Override
405 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
406 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this);
407 }
408 }
409
410
411
412
413
414 private static class MarsagliaTsangWangBase64Int32DiscreteSampler
415 extends AbstractMarsagliaTsangWangDiscreteSampler {
416
417 private final int t1;
418
419 private final int t2;
420
421 private final int t3;
422
423 private final int t4;
424
425
426 private final int[] table1;
427
428 private final int[] table2;
429
430 private final int[] table3;
431
432 private final int[] table4;
433
434 private final int[] table5;
435
436
437
438
439
440
441
442 MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
443 String distributionName,
444 int[] prob,
445 int offset) {
446 super(rng, distributionName);
447
448
449 int n1 = 0;
450 int n2 = 0;
451 int n3 = 0;
452 int n4 = 0;
453 int n5 = 0;
454 for (final int m : prob) {
455 n1 += getBase64Digit(m, 1);
456 n2 += getBase64Digit(m, 2);
457 n3 += getBase64Digit(m, 3);
458 n4 += getBase64Digit(m, 4);
459 n5 += getBase64Digit(m, 5);
460 }
461
462 table1 = new int[n1];
463 table2 = new int[n2];
464 table3 = new int[n3];
465 table4 = new int[n4];
466 table5 = new int[n5];
467
468
469 t1 = n1 << 24;
470 t2 = t1 + (n2 << 18);
471 t3 = t2 + (n3 << 12);
472 t4 = t3 + (n4 << 6);
473 n1 = n2 = n3 = n4 = n5 = 0;
474
475
476 for (int i = 0; i < prob.length; i++) {
477 final int m = prob[i];
478 final int k = i + offset;
479 n1 = fill(table1, n1, n1 + getBase64Digit(m, 1), k);
480 n2 = fill(table2, n2, n2 + getBase64Digit(m, 2), k);
481 n3 = fill(table3, n3, n3 + getBase64Digit(m, 3), k);
482 n4 = fill(table4, n4, n4 + getBase64Digit(m, 4), k);
483 n5 = fill(table5, n5, n5 + getBase64Digit(m, 5), k);
484 }
485 }
486
487
488
489
490
491 private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng,
492 MarsagliaTsangWangBase64Int32DiscreteSampler source) {
493 super(rng, source);
494 t1 = source.t1;
495 t2 = source.t2;
496 t3 = source.t3;
497 t4 = source.t4;
498 table1 = source.table1;
499 table2 = source.table2;
500 table3 = source.table3;
501 table4 = source.table4;
502 table5 = source.table5;
503 }
504
505
506
507
508
509
510
511
512
513
514 private static int fill(int[] table, int from, int to, int value) {
515 for (int i = from; i < to; i++) {
516 table[i] = value;
517 }
518 return to;
519 }
520
521 @Override
522 public int sample() {
523 final int j = rng.nextInt() >>> 2;
524 if (j < t1) {
525 return table1[j >>> 24];
526 }
527 if (j < t2) {
528 return table2[(j - t1) >>> 18];
529 }
530 if (j < t3) {
531 return table3[(j - t2) >>> 12];
532 }
533 if (j < t4) {
534 return table4[(j - t3) >>> 6];
535 }
536
537
538
539 return table5[j - t4];
540 }
541
542 @Override
543 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
544 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this);
545 }
546 }
547
548
549
550
/** Class contains only static methods and nested sampler factories. */
private MarsagliaTsangWangDiscreteSampler() {
    // Not instantiable.
}
552
553
554
555
556
557
558
559
560 private static int getBase64Digit(int m, int k) {
561 return (m >>> (30 - 6 * k)) & 63;
562 }
563
564
565
566
567
568
569
570
571 private static int toUnsignedInt30(double p) {
572 return (int) (p * INT_30 + 0.5);
573 }
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
589 String distributionName,
590 int[] prob,
591 int offset) {
592
593
594
595 final int maxIndex = prob.length + offset - 1;
596 if (maxIndex < INT_8) {
597 return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, distributionName, prob, offset);
598 }
599 if (maxIndex < INT_16) {
600 return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, distributionName, prob, offset);
601 }
602 return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, distributionName, prob, offset);
603 }
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618 public static final class Enumerated {
619
620 private static final String ENUMERATED_NAME = "Enumerated";
621
622
623 private Enumerated() {}
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
645 double[] probabilities) {
646 return createSampler(rng, ENUMERATED_NAME, normaliseProbabilities(probabilities), 0);
647 }
648
649
650
651
652
653
654
655
656
657
658 private static int[] normaliseProbabilities(double[] probabilities) {
659 final double sumProb = InternalUtils.validateProbabilities(probabilities);
660
661
662 final double normalisation = INT_30 / sumProb;
663 final int[] prob = new int[probabilities.length];
664 int sum = 0;
665 int max = 0;
666 int mode = 0;
667 for (int i = 0; i < prob.length; i++) {
668
669 final int p = (int) (probabilities[i] * normalisation + 0.5);
670 sum += p;
671
672 if (max < p) {
673 max = p;
674 mode = i;
675 }
676 prob[i] = p;
677 }
678
679
680
681 prob[mode] += INT_30 - sum;
682
683 return prob;
684 }
685 }
686
687
688
689
690 public static final class Poisson {
691
692 private static final String POISSON_NAME = "Poisson";
693
694
695
696
697
698
699
700
701
702 private static final double MAX_MEAN = 1024;
703
704
705
706
707
708 private static final double MEAN_THRESHOLD = 21.4;
709
710
711 private Poisson() {}
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
746 double mean) {
747 validatePoissonDistributionParameters(mean);
748
749
750 return mean < MEAN_THRESHOLD ?
751 createPoissonDistributionFromX0(rng, mean) :
752 createPoissonDistributionFromXMode(rng, mean);
753 }
754
755
756
757
758
759
760
761 private static void validatePoissonDistributionParameters(double mean) {
762 if (mean <= 0) {
763 throw new IllegalArgumentException("mean is not strictly positive: " + mean);
764 }
765 if (mean > MAX_MEAN) {
766 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
767 }
768 }
769
770
771
772
773
774
775
776
777 private static SharedStateDiscreteSampler createPoissonDistributionFromX0(
778 UniformRandomProvider rng, double mean) {
779 final double p0 = Math.exp(-mean);
780
781
782
783 double p = p0;
784 int i = 1;
785 while (p * DOUBLE_31 >= 1) {
786 p *= mean / i++;
787 }
788
789
790 final int size = i - 1;
791 final int[] prob = new int[size];
792
793 p = p0;
794 prob[0] = toUnsignedInt30(p);
795
796 int sum = prob[0];
797 for (i = 1; i < prob.length; i++) {
798 p *= mean / i;
799 prob[i] = toUnsignedInt30(p);
800 sum += prob[i];
801 }
802
803
804 prob[(int) mean] += Math.max(0, INT_30 - sum);
805
806
807 return createSampler(rng, POISSON_NAME, prob, 0);
808 }
809
810
811
812
813
814
815
816
817
818 private static SharedStateDiscreteSampler createPoissonDistributionFromXMode(
819 UniformRandomProvider rng, double mean) {
820
821
822
823
824 final int mode = (int) mean;
825
826
827 final double c = mean * Math.exp(-mean / mode);
828 double p = 1.0;
829 for (int i = 1; i <= mode; i++) {
830 p *= c / i;
831 }
832 final double pMode = p;
833
834
835
836 int i = mode + 1;
837 while (p * DOUBLE_31 >= 1) {
838 p *= mean / i++;
839 }
840 final int last = i - 2;
841
842
843 p = pMode;
844 int j = -1;
845 for (i = mode - 1; i >= 0; i--) {
846 p *= (i + 1) / mean;
847 if (p * DOUBLE_31 < 1) {
848 j = i;
849 break;
850 }
851 }
852
853
854
855 final int offset = j + 1;
856 final int size = last - offset + 1;
857 final int[] prob = new int[size];
858
859 p = pMode;
860 prob[mode - offset] = toUnsignedInt30(p);
861
862 int sum = prob[mode - offset];
863
864 for (i = mode + 1; i <= last; i++) {
865 p *= mean / i;
866 prob[i - offset] = toUnsignedInt30(p);
867 sum += prob[i - offset];
868 }
869
870 p = pMode;
871 for (i = mode - 1; i >= offset; i--) {
872 p *= (i + 1) / mean;
873 prob[i - offset] = toUnsignedInt30(p);
874 sum += prob[i - offset];
875 }
876
877
878
879 prob[mode - offset] += Math.max(0, INT_30 - sum);
880
881 return createSampler(rng, POISSON_NAME, prob, offset);
882 }
883 }
884
885
886
887
888 public static final class Binomial {
889
890 private static final String BINOMIAL_NAME = "Binomial";
891
892
893
894
895
896 private static class MarsagliaTsangWangFixedResultBinomialSampler
897 extends AbstractMarsagliaTsangWangDiscreteSampler {
898
899 private final int result;
900
901
902
903
904 MarsagliaTsangWangFixedResultBinomialSampler(int result) {
905 super(null, BINOMIAL_NAME);
906 this.result = result;
907 }
908
909 @Override
910 public int sample() {
911 return result;
912 }
913
914 @Override
915 public String toString() {
916 return BINOMIAL_NAME + " deviate";
917 }
918
919 @Override
920 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
921
922 return this;
923 }
924 }
925
926
927
928
929
930
931
932
933
934 private static class MarsagliaTsangWangInversionBinomialSampler
935 extends AbstractMarsagliaTsangWangDiscreteSampler {
936
937 private final int trials;
938
939 private final SharedStateDiscreteSampler sampler;
940
941
942
943
944
945 MarsagliaTsangWangInversionBinomialSampler(int trials,
946 SharedStateDiscreteSampler sampler) {
947 super(null, BINOMIAL_NAME);
948 this.trials = trials;
949 this.sampler = sampler;
950 }
951
952 @Override
953 public int sample() {
954 return trials - sampler.sample();
955 }
956
957 @Override
958 public String toString() {
959 return sampler.toString();
960 }
961
962 @Override
963 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
964 return new MarsagliaTsangWangInversionBinomialSampler(this.trials,
965 this.sampler.withUniformRandomProvider(rng));
966 }
967 }
968
969
970 private Binomial() {}
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
1003 int trials,
1004 double probabilityOfSuccess) {
1005 validateBinomialDistributionParameters(trials, probabilityOfSuccess);
1006
1007
1008 if (probabilityOfSuccess == 0) {
1009 return new MarsagliaTsangWangFixedResultBinomialSampler(0);
1010 }
1011 if (probabilityOfSuccess == 1) {
1012 return new MarsagliaTsangWangFixedResultBinomialSampler(trials);
1013 }
1014
1015
1016 if (trials >= INT_16) {
1017 throw new IllegalArgumentException("Unsupported number of trials: " + trials);
1018 }
1019
1020 return createBinomialDistributionSampler(rng, trials, probabilityOfSuccess);
1021 }
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031 private static void validateBinomialDistributionParameters(int trials, double probabilityOfSuccess) {
1032 if (trials < 0) {
1033 throw new IllegalArgumentException("Trials is not positive: " + trials);
1034 }
1035 if (probabilityOfSuccess < 0 || probabilityOfSuccess > 1) {
1036 throw new IllegalArgumentException("Probability is not in range [0,1]: " + probabilityOfSuccess);
1037 }
1038 }
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053 private static SharedStateDiscreteSampler createBinomialDistributionSampler(
1054 UniformRandomProvider rng, int trials, double probabilityOfSuccess) {
1055
1056
1057
1058
1059 final boolean useInversion = probabilityOfSuccess > 0.5;
1060 final double p = useInversion ? 1 - probabilityOfSuccess : probabilityOfSuccess;
1061
1062
1063 final double p0 = Math.exp(trials * Math.log(1 - p));
1064 if (p0 < Double.MIN_VALUE) {
1065 throw new IllegalArgumentException("Unable to compute distribution");
1066 }
1067
1068
1069 double t = p0;
1070 final double h = p / (1 - p);
1071
1072 int begin = 0;
1073 if (t * DOUBLE_31 < 1) {
1074
1075
1076
1077
1078
1079 for (int i = 1; i <= trials; i++) {
1080 t *= (trials + 1 - i) * h / i;
1081 if (t * DOUBLE_31 >= 1) {
1082 begin = i;
1083 break;
1084 }
1085 }
1086 }
1087
1088 int end = trials;
1089 for (int i = begin + 1; i <= trials; i++) {
1090 t *= (trials + 1 - i) * h / i;
1091 if (t * DOUBLE_31 < 1) {
1092 end = i - 1;
1093 break;
1094 }
1095 }
1096
1097 return createBinomialDistributionSamplerFromRange(rng, trials, p, useInversion,
1098 p0, begin, end);
1099 }
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114 private static SharedStateDiscreteSampler createBinomialDistributionSamplerFromRange(
1115 UniformRandomProvider rng, int trials, double p,
1116 boolean useInversion, double p0, int begin, int end) {
1117
1118
1119 final int size = end - begin + 1;
1120 final int[] prob = new int[size];
1121 double t = p0;
1122 final double h = p / (1 - p);
1123 for (int i = 1; i <= begin; i++) {
1124 t *= (trials + 1 - i) * h / i;
1125 }
1126 int sum = toUnsignedInt30(t);
1127 prob[0] = sum;
1128 for (int i = begin + 1; i <= end; i++) {
1129 t *= (trials + 1 - i) * h / i;
1130 prob[i - begin] = toUnsignedInt30(t);
1131 sum += prob[i - begin];
1132 }
1133
1134
1135
1136 final int mode = (int) ((trials + 1) * p) - begin;
1137 prob[mode] += Math.max(0, INT_30 - sum);
1138
1139 final SharedStateDiscreteSampler sampler = createSampler(rng, BINOMIAL_NAME, prob, begin);
1140
1141
1142 return useInversion ?
1143 new MarsagliaTsangWangInversionBinomialSampler(trials, sampler) :
1144 sampler;
1145 }
1146 }
1147 }