1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.rng.examples.jmh.sampling.distribution;
19
20 import org.apache.commons.math3.distribution.BinomialDistribution;
21 import org.apache.commons.math3.distribution.IntegerDistribution;
22 import org.apache.commons.math3.distribution.PoissonDistribution;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
25 import org.apache.commons.rng.sampling.distribution.DirichletSampler;
26 import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
27 import org.apache.commons.rng.sampling.distribution.FastLoadedDiceRollerDiscreteSampler;
28 import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
29 import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
30 import org.apache.commons.rng.simple.RandomSource;
31
32 import org.openjdk.jmh.annotations.Benchmark;
33 import org.openjdk.jmh.annotations.BenchmarkMode;
34 import org.openjdk.jmh.annotations.Fork;
35 import org.openjdk.jmh.annotations.Level;
36 import org.openjdk.jmh.annotations.Measurement;
37 import org.openjdk.jmh.annotations.Mode;
38 import org.openjdk.jmh.annotations.OutputTimeUnit;
39 import org.openjdk.jmh.annotations.Param;
40 import org.openjdk.jmh.annotations.Scope;
41 import org.openjdk.jmh.annotations.Setup;
42 import org.openjdk.jmh.annotations.State;
43 import org.openjdk.jmh.annotations.Warmup;
44
45 import java.util.Arrays;
46 import java.util.concurrent.TimeUnit;
47 import java.util.function.Supplier;
48
49 /**
50 * Executes benchmark to compare the speed of generation of random numbers from an enumerated
51 * discrete probability distribution.
52 */
53 @BenchmarkMode(Mode.AverageTime)
54 @OutputTimeUnit(TimeUnit.NANOSECONDS)
55 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
56 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
57 @State(Scope.Benchmark)
58 @Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
59 public class EnumeratedDistributionSamplersPerformance {
60 /**
61 * The value for the baseline generation of an {@code int} value.
62 *
63 * <p>This must NOT be final!</p>
64 */
65 private int value;
66
67 /**
68 * The random sources to use for testing. This is a smaller list than all the possible
69 * random sources; the list is composed of generators of different speeds.
70 */
71 @State(Scope.Benchmark)
72 public static class LocalRandomSources {
73 /**
74 * RNG providers.
75 *
76 * <p>Use different speeds.</p>
77 *
78 * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
79 * Commons RNG user guide</a>
80 */
81 @Param({"WELL_44497_B",
82 "ISAAC",
83 "XO_RO_SHI_RO_128_PLUS"})
84 private String randomSourceName;
85
86 /** RNG. */
87 private UniformRandomProvider generator;
88
89 /**
90 * @return the RNG.
91 */
92 public UniformRandomProvider getGenerator() {
93 return generator;
94 }
95
96 /** Create the random source. */
97 @Setup
98 public void setup() {
99 final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
100 generator = randomSource.create();
101 }
102 }
103
104 /**
105 * The {@link DiscreteSampler} samplers to use for testing. Creates the sampler for each
106 * random source.
107 *
108 * <p>This class is abstract. The probability distribution is created by implementations.</p>
109 */
110 @State(Scope.Benchmark)
111 public abstract static class SamplerSources extends LocalRandomSources {
112 /**
113 * The sampler type.
114 */
115 @Param({"BinarySearchDiscreteSampler",
116 "AliasMethodDiscreteSampler",
117 "GuideTableDiscreteSampler",
118 "MarsagliaTsangWangDiscreteSampler",
119 "FastLoadedDiceRollerDiscreteSampler",
120 "FastLoadedDiceRollerDiscreteSamplerLong",
121 "FastLoadedDiceRollerDiscreteSampler53",
122
123 // Uncomment to test non-default parameters
124 //"AliasMethodDiscreteSamplerNoPad", // Not optimal for sampling
125 //"AliasMethodDiscreteSamplerAlpha1",
126 //"AliasMethodDiscreteSamplerAlpha2",
127
128 // The AliasMethod memory requirement doubles for each alpha increment.
129 // A fair comparison is to use 2^alpha for the equivalent guide table method.
130 //"GuideTableDiscreteSamplerAlpha2",
131 //"GuideTableDiscreteSamplerAlpha4",
132 })
133 private String samplerType;
134
135 /** The factory. */
136 private Supplier<DiscreteSampler> factory;
137
138 /** The sampler. */
139 private DiscreteSampler sampler;
140
141 /**
142 * Gets the sampler.
143 *
144 * @return the sampler.
145 */
146 public DiscreteSampler getSampler() {
147 return sampler;
148 }
149
150 /** Create the distribution (per iteration as it may vary) and instantiates sampler. */
151 @Override
152 @Setup(Level.Iteration)
153 public void setup() {
154 super.setup();
155
156 final double[] probabilities = createProbabilities();
157 createSamplerFactory(getGenerator(), probabilities);
158 sampler = factory.get();
159 }
160
161 /**
162 * Creates the probabilities for the distribution.
163 *
164 * @return The probabilities.
165 */
166 protected abstract double[] createProbabilities();
167
168 /**
169 * Creates the sampler factory.
170 *
171 * @param rng The random generator.
172 * @param probabilities The probabilities.
173 */
174 private void createSamplerFactory(final UniformRandomProvider rng,
175 final double[] probabilities) {
176 if ("BinarySearchDiscreteSampler".equals(samplerType)) {
177 factory = () -> new BinarySearchDiscreteSampler(rng, probabilities);
178 } else if ("AliasMethodDiscreteSampler".equals(samplerType)) {
179 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities);
180 } else if ("AliasMethodDiscreteSamplerNoPad".equals(samplerType)) {
181 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities, -1);
182 } else if ("AliasMethodDiscreteSamplerAlpha1".equals(samplerType)) {
183 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities, 1);
184 } else if ("AliasMethodDiscreteSamplerAlpha2".equals(samplerType)) {
185 factory = () -> AliasMethodDiscreteSampler.of(rng, probabilities, 2);
186 } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
187 factory = () -> GuideTableDiscreteSampler.of(rng, probabilities);
188 } else if ("GuideTableDiscreteSamplerAlpha2".equals(samplerType)) {
189 factory = () -> GuideTableDiscreteSampler.of(rng, probabilities, 2);
190 } else if ("GuideTableDiscreteSamplerAlpha8".equals(samplerType)) {
191 factory = () -> GuideTableDiscreteSampler.of(rng, probabilities, 8);
192 } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
193 factory = () -> MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
194 } else if ("FastLoadedDiceRollerDiscreteSampler".equals(samplerType)) {
195 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities);
196 } else if ("FastLoadedDiceRollerDiscreteSamplerLong".equals(samplerType)) {
197 // Avoid exact floating-point arithmetic in construction.
198 // Frequencies must sum to less than 2^63; here the sum is ~2^62.
199 // This conversion may omit very small probabilities.
200 final double sum = Arrays.stream(probabilities).sum();
201 final long[] frequencies = Arrays.stream(probabilities)
202 .mapToLong(x -> Math.round(0x1.0p62 * x / sum))
203 .toArray();
204 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
205 } else if ("FastLoadedDiceRollerDiscreteSampler53".equals(samplerType)) {
206 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, 53);
207 } else {
208 throw new IllegalStateException();
209 }
210 }
211
212 /**
213 * Creates a new instance of the sampler.
214 *
215 * @return The sampler.
216 */
217 public DiscreteSampler createSampler() {
218 return factory.get();
219 }
220 }
221
222 /**
223 * Define known probability distributions for testing. These are expected to have well
224 * behaved cumulative probability functions.
225 */
226 @State(Scope.Benchmark)
227 public static class KnownDistributionSources extends SamplerSources {
228 /** The cumulative probability limit for unbounded distributions. */
229 private static final double CUMULATIVE_PROBABILITY_LIMIT = 1 - 1e-9;
230 /** Binomial distribution number of trials. */
231 private static final int BINOM_N = 67;
232 /** Binomial distribution probability of success. */
233 private static final double BINOM_P = 0.7;
234 /** Geometric distribution probability of success. */
235 private static final double GEO_P = 0.2;
236 /** Poisson distribution mean. */
237 private static final double POISS_MEAN = 3.22;
238 /** Bimodal distribution mean 1. */
239 private static final double BIMOD_MEAN1 = 10;
240 /** Bimodal distribution mean 1. */
241 private static final double BIMOD_MEAN2 = 20;
242
243 /**
244 * The distribution.
245 */
246 @Param({"Binomial_N67_P0.7",
247 "Geometric_P0.2",
248 "4SidedLoadedDie",
249 "Poisson_Mean3.22",
250 "Poisson_Mean10_Mean20"})
251 private String distribution;
252
253 /** {@inheritDoc} */
254 @Override
255 protected double[] createProbabilities() {
256 if ("Binomial_N67_P0.7".equals(distribution)) {
257 final BinomialDistribution dist = new BinomialDistribution(null, BINOM_N, BINOM_P);
258 return createProbabilities(dist, 0, BINOM_N);
259 } else if ("Geometric_P0.2".equals(distribution)) {
260 final double probabilityOfFailure = 1 - GEO_P;
261 // https://en.wikipedia.org/wiki/Geometric_distribution
262 // PMF = (1-p)^k * p
263 // k is number of failures before a success
264 double p = 1.0; // (1-p)^0
265 // Build until the cumulative function is big
266 double[] probabilities = new double[100];
267 double sum = 0;
268 int k = 0;
269 while (k < probabilities.length) {
270 probabilities[k] = p * GEO_P;
271 sum += probabilities[k++];
272 if (sum > CUMULATIVE_PROBABILITY_LIMIT) {
273 break;
274 }
275 // For the next PMF
276 p *= probabilityOfFailure;
277 }
278 return Arrays.copyOf(probabilities, k);
279 } else if ("4SidedLoadedDie".equals(distribution)) {
280 return new double[] {1.0 / 2, 1.0 / 3, 1.0 / 12, 1.0 / 12};
281 } else if ("Poisson_Mean3.22".equals(distribution)) {
282 final IntegerDistribution dist = createPoissonDistribution(POISS_MEAN);
283 final int max = dist.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
284 return createProbabilities(dist, 0, max);
285 } else if ("Poisson_Mean10_Mean20".equals(distribution)) {
286 // Create a Bimodel using two Poisson distributions
287 final IntegerDistribution dist1 = createPoissonDistribution(BIMOD_MEAN2);
288 final int max = dist1.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
289 final double[] p1 = createProbabilities(dist1, 0, max);
290 final double[] p2 = createProbabilities(createPoissonDistribution(BIMOD_MEAN1), 0, max);
291 for (int i = 0; i < p1.length; i++) {
292 p1[i] += p2[i];
293 }
294 // Leave to the distribution to normalise the sum
295 return p1;
296 }
297 throw new IllegalStateException();
298 }
299
300 /**
301 * Creates the poisson distribution.
302 *
303 * @param mean the mean
304 * @return the distribution
305 */
306 private static IntegerDistribution createPoissonDistribution(double mean) {
307 return new PoissonDistribution(null, mean,
308 PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
309 }
310
311 /**
312 * Creates the probabilities from the distribution.
313 *
314 * @param dist the distribution
315 * @param lower the lower bounds (inclusive)
316 * @param upper the upper bounds (inclusive)
317 * @return the probabilities
318 */
319 private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
320 double[] probabilities = new double[upper - lower + 1];
321 int index = 0;
322 for (int x = lower; x <= upper; x++) {
323 probabilities[index++] = dist.probability(x);
324 }
325 return probabilities;
326 }
327 }
328
329 /**
330 * Define random probability distributions of known size for testing. These are random but
331 * the average cumulative probability function will be a straight line given the increment
332 * average is 0.5.
333 */
334 @State(Scope.Benchmark)
335 public static class RandomDistributionSources extends SamplerSources {
336 /**
337 * The distribution size.
338 * These are spaced half-way between powers-of-2 to minimise the advantage of
339 * padding by the Alias method sampler.
340 */
341 @Param({"6",
342 //"12",
343 //"24",
344 //"48",
345 "96",
346 //"192",
347 //"384",
348 // Above 2048 forces the Alias method to use more than 64-bits for sampling
349 "3072"})
350 private int randomNonUniformSize;
351
352 /** {@inheritDoc} */
353 @Override
354 protected double[] createProbabilities() {
355 return RandomSource.XO_RO_SHI_RO_128_PP.create()
356 .doubles(randomNonUniformSize).toArray();
357 }
358 }
359
360 /**
361 * Sample random probability arrays from a Dirichlet distribution.
362 *
363 * <p>The distribution ensures the probabilities sum to 1.
364 * The <a href="https://en.wikipedia.org/wiki/Entropy_(information_theory)">entropy</a>
365 * of the probabilities increases with parameters k and alpha.
366 * The following shows the mean and sd of the entropy from 100 samples
367 * for a range of parameters.
368 * <pre>
369 * k alpha mean sd
370 * 4 0.500 1.299 0.374
371 * 4 1.000 1.531 0.294
372 * 4 2.000 1.754 0.172
373 * 8 0.500 2.087 0.348
374 * 8 1.000 2.490 0.266
375 * 8 2.000 2.707 0.142
376 * 16 0.500 3.023 0.287
377 * 16 1.000 3.454 0.166
378 * 16 2.000 3.693 0.095
379 * 32 0.500 4.008 0.182
380 * 32 1.000 4.406 0.125
381 * 32 2.000 4.692 0.075
382 * 64 0.500 4.986 0.151
383 * 64 1.000 5.392 0.115
384 * 64 2.000 5.680 0.048
385 * </pre>
386 */
387 @State(Scope.Benchmark)
388 public static class DirichletDistributionSources extends SamplerSources {
389 /** Number of categories. */
390 @Param({"4", "8", "16"})
391 private int k;
392
393 /** Concentration parameter. */
394 @Param({"0.5", "1", "2"})
395 private double alpha;
396
397 /** {@inheritDoc} */
398 @Override
399 protected double[] createProbabilities() {
400 return DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(),
401 k, alpha).sample();
402 }
403 }
404
405 /**
406 * The {@link FastLoadedDiceRollerDiscreteSampler} samplers to use for testing.
407 * Creates the sampler for each random source and the probabilities using
408 * a Dirichlet distribution.
409 *
410 * <p>This class is a specialized source to allow examination of the effect of the
411 * {@link FastLoadedDiceRollerDiscreteSampler} {@code alpha} parameter.
412 */
413 @State(Scope.Benchmark)
414 public static class FastLoadedDiceRollerDiscreteSamplerSources extends LocalRandomSources {
415 /** Number of categories. */
416 @Param({"4", "8", "16"})
417 private int k;
418
419 /** Concentration parameter. */
420 @Param({"0.5", "1", "2"})
421 private double concentration;
422
423 /** The constructor {@code alpha} parameter. */
424 @Param({"0", "30", "53"})
425 private int alpha;
426
427 /** The factory. */
428 private Supplier<DiscreteSampler> factory;
429
430 /** The sampler. */
431 private DiscreteSampler sampler;
432
433 /**
434 * Gets the sampler.
435 *
436 * @return the sampler.
437 */
438 public DiscreteSampler getSampler() {
439 return sampler;
440 }
441
442 /** Create the distribution probabilities (per iteration as it may vary), the sampler
443 * factory and instantiates sampler. */
444 @Override
445 @Setup(Level.Iteration)
446 public void setup() {
447 super.setup();
448
449 final double[] probabilities =
450 DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(),
451 k, concentration).sample();
452 final UniformRandomProvider rng = getGenerator();
453 factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, alpha);
454 sampler = factory.get();
455 }
456
457 /**
458 * Creates a new instance of the sampler.
459 *
460 * @return The sampler.
461 */
462 public DiscreteSampler createSampler() {
463 return factory.get();
464 }
465 }
466
467 /**
468 * Compute a sample by binary search of the cumulative probability distribution.
469 */
470 static final class BinarySearchDiscreteSampler
471 implements DiscreteSampler {
472 /** Underlying source of randomness. */
473 private final UniformRandomProvider rng;
474 /**
475 * The cumulative probability table.
476 */
477 private final double[] cumulativeProbabilities;
478
479 /**
480 * @param rng Generator of uniformly distributed random numbers.
481 * @param probabilities The probabilities.
482 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
483 * probability is negative, infinite or {@code NaN}, or the sum of all
484 * probabilities is not strictly positive.
485 */
486 BinarySearchDiscreteSampler(UniformRandomProvider rng,
487 double[] probabilities) {
488 // Minimal set-up validation
489 if (probabilities == null || probabilities.length == 0) {
490 throw new IllegalArgumentException("Probabilities must not be empty.");
491 }
492
493 final int size = probabilities.length;
494 cumulativeProbabilities = new double[size];
495
496 double sumProb = 0;
497 int count = 0;
498 for (final double prob : probabilities) {
499 if (prob < 0 ||
500 Double.isInfinite(prob) ||
501 Double.isNaN(prob)) {
502 throw new IllegalArgumentException("Invalid probability: " +
503 prob);
504 }
505
506 // Compute and store cumulative probability.
507 sumProb += prob;
508 cumulativeProbabilities[count++] = sumProb;
509 }
510
511 if (Double.isInfinite(sumProb) || sumProb <= 0) {
512 throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
513 }
514
515 this.rng = rng;
516
517 // Normalise cumulative probability.
518 for (int i = 0; i < size; i++) {
519 final double norm = cumulativeProbabilities[i] / sumProb;
520 cumulativeProbabilities[i] = (norm < 1) ? norm : 1.0;
521 }
522 }
523
524 /** {@inheritDoc} */
525 @Override
526 public int sample() {
527 final double u = rng.nextDouble();
528
529 // Java binary search
530 //int index = Arrays.binarySearch(cumulativeProbabilities, u);
531 //if (index < 0) {
532 // index = -index - 1;
533 //}
534 //
535 //return index < cumulativeProbabilities.length ?
536 // index :
537 // cumulativeProbabilities.length - 1;
538
539 // Binary search within known cumulative probability table.
540 // Find x so that u > f[x-1] and u <= f[x].
541 // This is a looser search than Arrays.binarySearch:
542 // - The output is x = upper.
543 // - The table stores probabilities where f[0] is >= 0 and the max == 1.0.
544 // - u should be >= 0 and <= 1 (or the random generator is broken).
545 // - It avoids comparisons using Double.doubleToLongBits.
546 // - It avoids the low likelihood of equality between two doubles for fast exit
547 // so uses only 1 compare per loop.
548 int lower = 0;
549 int upper = cumulativeProbabilities.length - 1;
550 while (lower < upper) {
551 final int mid = (lower + upper) >>> 1;
552 final double midVal = cumulativeProbabilities[mid];
553 if (u > midVal) {
554 // Change lower such that
555 // u > f[lower - 1]
556 lower = mid + 1;
557 } else {
558 // Change upper such that
559 // u <= f[upper]
560 upper = mid;
561 }
562 }
563 return upper;
564 }
565 }
566
// Benchmark methods below.
568
569 /**
570 * Baseline for the JMH timing overhead for production of an {@code int} value.
571 *
572 * @return the {@code int} value
573 */
574 @Benchmark
575 public int baselineInt() {
576 return value;
577 }
578
579 /**
580 * Baseline for the production of a {@code double} value.
581 * This is used to assess the performance of the underlying random source.
582 *
583 * @param sources Source of randomness.
584 * @return the {@code int} value
585 */
586 @Benchmark
587 public int baselineNextDouble(LocalRandomSources sources) {
588 return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
589 }
590
591 /**
592 * Run the sampler.
593 *
594 * @param sources Source of randomness.
595 * @return the sample value
596 */
597 @Benchmark
598 public int sampleKnown(KnownDistributionSources sources) {
599 return sources.getSampler().sample();
600 }
601
602 /**
603 * Create and run the sampler.
604 *
605 * @param sources Source of randomness.
606 * @return the sample value
607 */
608 @Benchmark
609 public int singleSampleKnown(KnownDistributionSources sources) {
610 return sources.createSampler().sample();
611 }
612
613 /**
614 * Run the sampler.
615 *
616 * @param sources Source of randomness.
617 * @return the sample value
618 */
619 @Benchmark
620 public int sampleRandom(RandomDistributionSources sources) {
621 return sources.getSampler().sample();
622 }
623
624 /**
625 * Create and run the sampler.
626 *
627 * @param sources Source of randomness.
628 * @return the sample value
629 */
630 @Benchmark
631 public int singleSampleRandom(RandomDistributionSources sources) {
632 return sources.createSampler().sample();
633 }
634
635 /**
636 * Run the sampler.
637 *
638 * @param sources Source of randomness.
639 * @return the sample value
640 */
641 @Benchmark
642 public int sampleDirichlet(DirichletDistributionSources sources) {
643 return sources.getSampler().sample();
644 }
645
646 /**
647 * Create and run the sampler.
648 *
649 * @param sources Source of randomness.
650 * @return the sample value
651 */
652 @Benchmark
653 public int singleSampleDirichlet(DirichletDistributionSources sources) {
654 return sources.createSampler().sample();
655 }
656
657 /**
658 * Run the sampler.
659 *
660 * @param sources Source of randomness.
661 * @return the sample value
662 */
663 @Benchmark
664 public int sampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) {
665 return sources.getSampler().sample();
666 }
667
668 /**
669 * Create and run the sampler.
670 *
671 * @param sources Source of randomness.
672 * @return the sample value
673 */
674 @Benchmark
675 public int singleSampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) {
676 return sources.createSampler().sample();
677 }
678 }