View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.examples.jmh.sampling.distribution;
19  
20  import java.util.concurrent.TimeUnit;
21  import java.util.function.DoubleFunction;
22  import org.apache.commons.rng.RandomProviderState;
23  import org.apache.commons.rng.RestorableUniformRandomProvider;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.sampling.PermutationSampler;
26  import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
27  import org.apache.commons.rng.sampling.distribution.PoissonSampler;
28  import org.apache.commons.rng.sampling.distribution.PoissonSamplerCache;
29  import org.apache.commons.rng.simple.RandomSource;
30  import org.openjdk.jmh.annotations.Benchmark;
31  import org.openjdk.jmh.annotations.BenchmarkMode;
32  import org.openjdk.jmh.annotations.Fork;
33  import org.openjdk.jmh.annotations.Measurement;
34  import org.openjdk.jmh.annotations.Mode;
35  import org.openjdk.jmh.annotations.OutputTimeUnit;
36  import org.openjdk.jmh.annotations.Param;
37  import org.openjdk.jmh.annotations.Scope;
38  import org.openjdk.jmh.annotations.Setup;
39  import org.openjdk.jmh.annotations.State;
40  import org.openjdk.jmh.annotations.Warmup;
41  import org.openjdk.jmh.infra.Blackhole;
42  
43  /**
44   * Executes benchmark to compare the speed of generation of Poisson random numbers when using a
45   * cache.
46   *
47   * <p>The benchmark is designed for a worse case scenario of Poisson means that are uniformly spread
48   * over a range and non-integer. A single sample is required per mean, E.g.</p>
49   *
50   * <pre>
51   * int min = 40;
52   * int max = 1000;
53   * int range = max - min;
54   * UniformRandomProvider rng = ...;
55   *
56   * // Compare ...
57   * for (int i = 0; i &lt; 1000; i++) {
58   *   PoissonSampler.of(rng, min + rng.nextDouble() * range).sample();
59   * }
60   *
61   * // To ...
62   * PoissonSamplerCache cache = new PoissonSamplerCache(min, max);
63   * for (int i = 0; i &lt; 1000; i++) {
64   *   PoissonSamplerCache.createPoissonSampler(rng, min + rng.nextDouble() * range).sample();
65   * }
66   * </pre>
67   *
68   * <p>The alternative scenario where the means are integer is not considered as this could be easily
69   * handled by creating an array to hold the PoissonSamplers for each mean. This does not require any
70   * specialised caching of state and is simple enough to perform for single threaded applications:</p>
71   *
72   * <pre>
73   * public class SimpleUnsafePoissonSamplerCache {
74   *   int min = 50;
75   *   int max = 100;
76   *   PoissonSampler[] samplers = new PoissonSampler[max - min + 1];
77   *
78   *   public PoissonSampler createPoissonSampler(UniformRandomProvider rng, int mean) {
79   *     if (mean &lt; min || mean &gt; max) {
80   *       return PoissonSampler.of(rng, mean);
81   *     }
82   *     int index = mean - min;
83   *     PoissonSampler sample = samplers[index];
84   *     if (sampler == null) {
85   *       sampler = PoissonSampler.of(rng, mean);
86   *       samplers[index] = sampler;
87   *     }
88   *     return sampler;
89   *   }
90   * }
91   * </pre>
92   *
93   * <p>Note that in this example the UniformRandomProvider is also cached and so this is only
94   * applicable to a single threaded application. Thread safety could be ensured using the
95   * {@link org.apache.commons.rng.sampling.SharedStateSampler SharedStateSampler} functionality
96   * of the cached sampler.</p>
97   *
98   * <p>Re-written to use the PoissonSamplerCache would provide a new PoissonSampler per call in a
99   * thread-safe manner:
100  *
101  * <pre>
102  * public class SimplePoissonSamplerCache {
103  *   int min = 50;
104  *   int max = 100;
105  *   PoissonSamplerCache samplers = new PoissonSamplerCache(min, max);
106  *
107  *   public PoissonSampler createPoissonSampler(UniformRandomProvider rng, int mean) {
108  *       return samplers.createPoissonSampler(rng, mean);
109  *   }
110  * }
111  * </pre>
112 */
113 @BenchmarkMode(Mode.AverageTime)
114 @OutputTimeUnit(TimeUnit.MICROSECONDS)
115 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
116 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
117 @State(Scope.Benchmark)
118 @Fork(value = 1, jvmArgs = { "-server", "-Xms128M", "-Xmx128M" })
119 public class PoissonSamplerCachePerformance {
120     /** Number of samples per run. */
121     private static final int NUM_SAMPLES = 100_000;
122     /**
123      * Number of range samples.
124      *
125      * <p>Note: The LargeMeanPoissonSampler will not use a SmallMeanPoissonSampler
126      * if the mean is an integer. This will occur if the [range sample] * range is
127      * an integer.
128      *
129      * <p>If the SmallMeanPoissonSampler is not used then the cache has more
130      * advantage over the uncached version as relatively more time is spent in
131      * initialising the algorithm.
132      *
133      * <p>To avoid this use a prime number above the maximum range
134      * (currently 4096). Any number (n/RANGE_SAMPLES) * range will not be integer
135      * with {@code n < RANGE_SAMPLES} and {@code range < RANGE_SAMPLES} (unless n==0).
136      */
137     private static final int RANGE_SAMPLE_SIZE = 4099;
138     /** The size of the seed. */
139     private static final int SEED_SIZE = 128;
140 
141     /**
142      * Seed used to ensure the tests are the same. This can be different per
143      * benchmark, but should be the same within the benchmark.
144      */
145     private static final int[] SEED;
146 
147     /**
148      * The range sample. Should contain doubles in the range 0 inclusive to 1 exclusive.
149      *
150      * <p>The range sample is used to create a mean using:
151      * rangeMin + sample * (rangeMax - rangeMin).
152      *
153      * <p>Ideally this should be large enough to fully sample the
154      * range when expressed as discrete integers, i.e. no sparseness, and random.
155      */
156     private static final double[] RANGE_SAMPLE;
157 
158     static {
159         // Build a random seed for all the tests
160         SEED = new int[SEED_SIZE];
161         final UniformRandomProvider rng = RandomSource.MWC_256.create();
162         for (int i = 0; i < SEED.length; i++) {
163             SEED[i] = rng.nextInt();
164         }
165 
166         final int size = RANGE_SAMPLE_SIZE;
167         final int[] sample = PermutationSampler.natural(size);
168         PermutationSampler.shuffle(rng, sample);
169 
170         RANGE_SAMPLE = new double[size];
171         for (int i = 0; i < size; i++) {
172             // Note: This will have one occurrence of zero in the range.
173             // This will create at least one LargeMeanPoissonSampler that will
174             // not use a SmallMeanPoissonSampler. The different performance of this
175             // will be lost among the other samples.
176             RANGE_SAMPLE[i] = (double) sample[i] / size;
177         }
178     }
179 
180     /**
181      * The benchmark state (retrieve the various "RandomSource"s).
182      */
183     @State(Scope.Benchmark)
184     public static class Sources {
185         /**
186          * RNG providers.
187          *
188          * <p>Use different speeds.</p>
189          *
190          * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
191          *      Commons RNG user guide</a>
192          */
193         @Param({ "SPLIT_MIX_64",
194             // Comment in for slower generators
195             //"MWC_256", "KISS", "WELL_1024_A", "WELL_44497_B"
196             })
197         private String randomSourceName;
198 
199         /** RNG. */
200         private RestorableUniformRandomProvider generator;
201 
202         /**
203          * The state of the generator at the start of the test (for reproducible
204          * results).
205          */
206         private RandomProviderState state;
207 
208         /**
209          * @return the RNG.
210          */
211         public UniformRandomProvider getGenerator() {
212             generator.restoreState(state);
213             return generator;
214         }
215 
216         /** Instantiates generator. */
217         @Setup
218         public void setup() {
219             final RandomSource randomSource = RandomSource
220                     .valueOf(randomSourceName);
221             // Use the same seed
222             generator = randomSource.create(SEED.clone());
223             state = generator.saveState();
224         }
225     }
226 
227     /**
228      * The range of mean values for testing the cache.
229      */
230     @State(Scope.Benchmark)
231     public static class MeanRange {
232         /**
233          * Test range.
234          *
235          * <p>The covers the best case scenario of caching everything (range=1) and upwards
236          * in powers of 4.
237          */
238         @Param({ "1", "4", "16", "64", "256", "1024", "4096"})
239         private double range;
240 
241         /**
242          * Gets the mean.
243          *
244          * @param i the index
245          * @return the mean
246          */
247         public double getMean(int i) {
248             return getMin() + RANGE_SAMPLE[i % RANGE_SAMPLE.length] * range;
249         }
250 
251         /**
252          * Gets the min of the range.
253          *
254          * @return the min
255          */
256         public double getMin() {
257             return PoissonSamplerCache.getMinimumCachedMean();
258         }
259 
260         /**
261          * Gets the max of the range.
262          *
263          * @return the max
264          */
265         public double getMax() {
266             return getMin() + range;
267         }
268     }
269 
270     /**
271      * Exercises a poisson sampler created for a single use with a range of means.
272      *
273      * @param factory The factory.
274      * @param range   The range of means.
275      * @param bh      Data sink.
276      */
277     private static void runSample(DoubleFunction<DiscreteSampler> factory,
278                                   MeanRange range,
279                                   Blackhole bh) {
280         for (int i = 0; i < NUM_SAMPLES; i++) {
281             bh.consume(factory.apply(range.getMean(i)).sample());
282         }
283     }
284 
285     // Benchmarks methods below.
286 
287     /**
288      * @param sources Source of randomness.
289      * @param range   The range.
290      * @param bh      Data sink.
291      */
292     @Benchmark
293     public void runPoissonSampler(Sources sources,
294                                   MeanRange range,
295                                   Blackhole bh) {
296         final UniformRandomProvider r = sources.getGenerator();
297         final DoubleFunction<DiscreteSampler> factory = mean -> PoissonSampler.of(r, mean);
298         runSample(factory, range, bh);
299     }
300 
301     /**
302      * @param sources Source of randomness.
303      * @param range   The range.
304      * @param bh      Data sink.
305      */
306     @Benchmark
307     public void runPoissonSamplerCacheWhenEmpty(Sources sources,
308                                                 MeanRange range,
309                                                 Blackhole bh) {
310         final UniformRandomProvider r = sources.getGenerator();
311         final PoissonSamplerCache cache = new PoissonSamplerCache(0, 0);
312         final DoubleFunction<DiscreteSampler> factory = mean -> cache.createSharedStateSampler(r, mean);
313         runSample(factory, range, bh);
314     }
315 
316     /**
317      * @param sources Source of randomness.
318      * @param range   The range.
319      * @param bh      Data sink.
320      */
321     @Benchmark
322     public void runPoissonSamplerCache(Sources sources,
323                                        MeanRange range,
324                                        Blackhole bh) {
325         final UniformRandomProvider r = sources.getGenerator();
326         final PoissonSamplerCache cache = new PoissonSamplerCache(
327                 range.getMin(), range.getMax());
328         final DoubleFunction<DiscreteSampler> factory = mean -> cache.createSharedStateSampler(r, mean);
329         runSample(factory, range, bh);
330     }
331 }