1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.examples.jmh.sampling.distribution;
19
20 import org.apache.commons.math3.distribution.BinomialDistribution;
21 import org.apache.commons.math3.distribution.IntegerDistribution;
22 import org.apache.commons.math3.distribution.PoissonDistribution;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
25 import org.apache.commons.rng.sampling.distribution.DirichletSampler;
26 import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
27 import org.apache.commons.rng.sampling.distribution.FastLoadedDiceRollerDiscreteSampler;
28 import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
29 import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
30 import org.apache.commons.rng.simple.RandomSource;
31
32 import org.openjdk.jmh.annotations.Benchmark;
33 import org.openjdk.jmh.annotations.BenchmarkMode;
34 import org.openjdk.jmh.annotations.Fork;
35 import org.openjdk.jmh.annotations.Level;
36 import org.openjdk.jmh.annotations.Measurement;
37 import org.openjdk.jmh.annotations.Mode;
38 import org.openjdk.jmh.annotations.OutputTimeUnit;
39 import org.openjdk.jmh.annotations.Param;
40 import org.openjdk.jmh.annotations.Scope;
41 import org.openjdk.jmh.annotations.Setup;
42 import org.openjdk.jmh.annotations.State;
43 import org.openjdk.jmh.annotations.Warmup;
44
45 import java.util.Arrays;
46 import java.util.concurrent.TimeUnit;
47 import java.util.function.Supplier;
48
49
50
51
52
53 @BenchmarkMode(Mode.AverageTime)
54 @OutputTimeUnit(TimeUnit.NANOSECONDS)
55 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
56 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
57 @State(Scope.Benchmark)
58 @Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
59 public class EnumeratedDistributionSamplersPerformance {
60
61
62
63
64
65 private int value;
66
67
68
69
70
71 @State(Scope.Benchmark)
72 public static class LocalRandomSources {
73
74
75
76
77
78
79
80
81 @Param({"WELL_44497_B",
82 "ISAAC",
83 "XO_RO_SHI_RO_128_PLUS"})
84 private String randomSourceName;
85
86
87 private UniformRandomProvider generator;
88
89
90
91
92 public UniformRandomProvider getGenerator() {
93 return generator;
94 }
95
96
97 @Setup
98 public void setup() {
99 final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
100 generator = randomSource.create();
101 }
102 }
103
104
105
106
107
108
109
110 @State(Scope.Benchmark)
111 public abstract static class SamplerSources extends LocalRandomSources {
112
113
114
115 @Param({"BinarySearchDiscreteSampler",
116 "AliasMethodDiscreteSampler",
117 "GuideTableDiscreteSampler",
118 "MarsagliaTsangWangDiscreteSampler",
119 "FastLoadedDiceRollerDiscreteSampler",
120 "FastLoadedDiceRollerDiscreteSamplerLong",
121 "FastLoadedDiceRollerDiscreteSampler53",
122
123
124
125
126
127
128
129
130
131
132 })
133 private String samplerType;
134
135
136 private Supplier<DiscreteSampler> factory;
137
138
139 private DiscreteSampler sampler;
140
141
142
143
144
145
146 public DiscreteSampler getSampler() {
147 return sampler;
148 }
149
150
151 @Override
152 @Setup(Level.Iteration)
153 public void setup() {
154 super.setup();
155
156 final double[] probabilities = createProbabilities();
157 createSamplerFactory(getGenerator(), probabilities);
158 sampler = factory.get();
159 }
160
161
162
163
164
165
166 protected abstract double[] createProbabilities();
167
168
169
170
171
172
173
174 private void createSamplerFactory(final UniformRandomProvider rng,
175 final double[] probabilities) {
176 if ("BinarySearchDiscreteSampler".equals(samplerType)) {
177 factory = () -> new BinarySearchDiscreteSampler(rng, probabilities);
178 } else if ("AliasMethodDiscreteSampler".equals(samplerType)) {
179 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities);
180 } else if ("AliasMethodDiscreteSamplerNoPad".equals(samplerType)) {
181 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities, -1);
182 } else if ("AliasMethodDiscreteSamplerAlpha1".equals(samplerType)) {
183 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities, 1);
184 } else if ("AliasMethodDiscreteSamplerAlpha2".equals(samplerType)) {
185 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities, 2);
186 } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
187 factory = () -> GuideTableDiscreteSampler.of(rng, probabilities);
188 } else if ("GuideTableDiscreteSamplerAlpha2".equals(samplerType)) {
189 factory = () -> GuideTableDiscreteSampler.of(rng, probabilities, 2);
190 } else if ("GuideTableDiscreteSamplerAlpha8".equals(samplerType)) {
191 factory = () -> GuideTableDiscreteSampler.of(rng, probabilities, 8);
192 } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
193 factory = () -> MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
194 } else if ("FastLoadedDiceRollerDiscreteSampler".equals(samplerType)) {
195 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities);
196 } else if ("FastLoadedDiceRollerDiscreteSamplerLong".equals(samplerType)) {
197
198
199
200 final double sum = Arrays.stream(probabilities).sum();
201 final long[] frequencies = Arrays.stream(probabilities)
202 .mapToLong(x -> Math.round(0x1.0p62 * x / sum))
203 .toArray();
204 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
205 } else if ("FastLoadedDiceRollerDiscreteSampler53".equals(samplerType)) {
206 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, 53);
207 } else {
208 throw new IllegalStateException();
209 }
210 }
211
212
213
214
215
216
217 public DiscreteSampler createSampler() {
218 return factory.get();
219 }
220 }
221
222
223
224
225
226 @State(Scope.Benchmark)
227 public static class KnownDistributionSources extends SamplerSources {
228
229 private static final double CUMULATIVE_PROBABILITY_LIMIT = 1 - 1e-9;
230
231 private static final int BINOM_N = 67;
232
233 private static final double BINOM_P = 0.7;
234
235 private static final double GEO_P = 0.2;
236
237 private static final double POISS_MEAN = 3.22;
238
239 private static final double BIMOD_MEAN1 = 10;
240
241 private static final double BIMOD_MEAN2 = 20;
242
243
244
245
246 @Param({"Binomial_N67_P0.7",
247 "Geometric_P0.2",
248 "4SidedLoadedDie",
249 "Poisson_Mean3.22",
250 "Poisson_Mean10_Mean20"})
251 private String distribution;
252
253
254 @Override
255 protected double[] createProbabilities() {
256 if ("Binomial_N67_P0.7".equals(distribution)) {
257 final BinomialDistribution dist = new BinomialDistribution(null, BINOM_N, BINOM_P);
258 return createProbabilities(dist, 0, BINOM_N);
259 } else if ("Geometric_P0.2".equals(distribution)) {
260 final double probabilityOfFailure = 1 - GEO_P;
261
262
263
264 double p = 1.0;
265
266 double[] probabilities = new double[100];
267 double sum = 0;
268 int k = 0;
269 while (k < probabilities.length) {
270 probabilities[k] = p * GEO_P;
271 sum += probabilities[k++];
272 if (sum > CUMULATIVE_PROBABILITY_LIMIT) {
273 break;
274 }
275
276 p *= probabilityOfFailure;
277 }
278 return Arrays.copyOf(probabilities, k);
279 } else if ("4SidedLoadedDie".equals(distribution)) {
280 return new double[] {1.0 / 2, 1.0 / 3, 1.0 / 12, 1.0 / 12};
281 } else if ("Poisson_Mean3.22".equals(distribution)) {
282 final IntegerDistribution dist = createPoissonDistribution(POISS_MEAN);
283 final int max = dist.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
284 return createProbabilities(dist, 0, max);
285 } else if ("Poisson_Mean10_Mean20".equals(distribution)) {
286
287 final IntegerDistribution dist1 = createPoissonDistribution(BIMOD_MEAN2);
288 final int max = dist1.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
289 final double[] p1 = createProbabilities(dist1, 0, max);
290 final double[] p2 = createProbabilities(createPoissonDistribution(BIMOD_MEAN1), 0, max);
291 for (int i = 0; i < p1.length; i++) {
292 p1[i] += p2[i];
293 }
294
295 return p1;
296 }
297 throw new IllegalStateException();
298 }
299
300
301
302
303
304
305
306 private static IntegerDistribution createPoissonDistribution(double mean) {
307 return new PoissonDistribution(null, mean,
308 PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
309 }
310
311
312
313
314
315
316
317
318
319 private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
320 double[] probabilities = new double[upper - lower + 1];
321 int index = 0;
322 for (int x = lower; x <= upper; x++) {
323 probabilities[index++] = dist.probability(x);
324 }
325 return probabilities;
326 }
327 }
328
329
330
331
332
333
334 @State(Scope.Benchmark)
335 public static class RandomDistributionSources extends SamplerSources {
336
337
338
339
340
341 @Param({"6",
342
343
344
345 "96",
346
347
348
349 "3072"})
350 private int randomNonUniformSize;
351
352
353 @Override
354 protected double[] createProbabilities() {
355 return RandomSource.XO_RO_SHI_RO_128_PP.create()
356 .doubles(randomNonUniformSize).toArray();
357 }
358 }
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387 @State(Scope.Benchmark)
388 public static class DirichletDistributionSources extends SamplerSources {
389
390 @Param({"4", "8", "16"})
391 private int k;
392
393
394 @Param({"0.5", "1", "2"})
395 private double alpha;
396
397
398 @Override
399 protected double[] createProbabilities() {
400 return DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(),
401 k, alpha).sample();
402 }
403 }
404
405
406
407
408
409
410
411
412
413 @State(Scope.Benchmark)
414 public static class FastLoadedDiceRollerDiscreteSamplerSources extends LocalRandomSources {
415
416 @Param({"4", "8", "16"})
417 private int k;
418
419
420 @Param({"0.5", "1", "2"})
421 private double concentration;
422
423
424 @Param({"0", "30", "53"})
425 private int alpha;
426
427
428 private Supplier<DiscreteSampler> factory;
429
430
431 private DiscreteSampler sampler;
432
433
434
435
436
437
438 public DiscreteSampler getSampler() {
439 return sampler;
440 }
441
442
443
444 @Override
445 @Setup(Level.Iteration)
446 public void setup() {
447 super.setup();
448
449 final double[] probabilities =
450 DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(),
451 k, concentration).sample();
452 final UniformRandomProvider rng = getGenerator();
453 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, alpha);
454 sampler = factory.get();
455 }
456
457
458
459
460
461
462 public DiscreteSampler createSampler() {
463 return factory.get();
464 }
465 }
466
467
468
469
470 static final class BinarySearchDiscreteSampler
471 implements DiscreteSampler {
472
473 private final UniformRandomProvider rng;
474
475
476
477 private final double[] cumulativeProbabilities;
478
479
480
481
482
483
484
485
486 BinarySearchDiscreteSampler(UniformRandomProvider rng,
487 double[] probabilities) {
488
489 if (probabilities == null || probabilities.length == 0) {
490 throw new IllegalArgumentException("Probabilities must not be empty.");
491 }
492
493 final int size = probabilities.length;
494 cumulativeProbabilities = new double[size];
495
496 double sumProb = 0;
497 int count = 0;
498 for (final double prob : probabilities) {
499 if (prob < 0 ||
500 Double.isInfinite(prob) ||
501 Double.isNaN(prob)) {
502 throw new IllegalArgumentException("Invalid probability: " +
503 prob);
504 }
505
506
507 sumProb += prob;
508 cumulativeProbabilities[count++] = sumProb;
509 }
510
511 if (Double.isInfinite(sumProb) || sumProb <= 0) {
512 throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
513 }
514
515 this.rng = rng;
516
517
518 for (int i = 0; i < size; i++) {
519 final double norm = cumulativeProbabilities[i] / sumProb;
520 cumulativeProbabilities[i] = (norm < 1) ? norm : 1.0;
521 }
522 }
523
524
525 @Override
526 public int sample() {
527 final double u = rng.nextDouble();
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548 int lower = 0;
549 int upper = cumulativeProbabilities.length - 1;
550 while (lower < upper) {
551 final int mid = (lower + upper) >>> 1;
552 final double midVal = cumulativeProbabilities[mid];
553 if (u > midVal) {
554
555
556 lower = mid + 1;
557 } else {
558
559
560 upper = mid;
561 }
562 }
563 return upper;
564 }
565 }
566
567
568
569
570
571
572
573
574 @Benchmark
575 public int baselineInt() {
576 return value;
577 }
578
579
580
581
582
583
584
585
586 @Benchmark
587 public int baselineNextDouble(LocalRandomSources sources) {
588 return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
589 }
590
591
592
593
594
595
596
597 @Benchmark
598 public int sampleKnown(KnownDistributionSources sources) {
599 return sources.getSampler().sample();
600 }
601
602
603
604
605
606
607
608 @Benchmark
609 public int singleSampleKnown(KnownDistributionSources sources) {
610 return sources.createSampler().sample();
611 }
612
613
614
615
616
617
618
619 @Benchmark
620 public int sampleRandom(RandomDistributionSources sources) {
621 return sources.getSampler().sample();
622 }
623
624
625
626
627
628
629
630 @Benchmark
631 public int singleSampleRandom(RandomDistributionSources sources) {
632 return sources.createSampler().sample();
633 }
634
635
636
637
638
639
640
641 @Benchmark
642 public int sampleDirichlet(DirichletDistributionSources sources) {
643 return sources.getSampler().sample();
644 }
645
646
647
648
649
650
651
652 @Benchmark
653 public int singleSampleDirichlet(DirichletDistributionSources sources) {
654 return sources.createSampler().sample();
655 }
656
657
658
659
660
661
662
663 @Benchmark
664 public int sampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) {
665 return sources.getSampler().sample();
666 }
667
668
669
670
671
672
673
674 @Benchmark
675 public int singleSampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) {
676 return sources.createSampler().sample();
677 }
678 }