1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.examples.jmh.sampling;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.sampling.ObjectSampler;
22 import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
23 import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
24 import org.apache.commons.rng.simple.RandomSource;
25 import org.openjdk.jmh.annotations.Benchmark;
26 import org.openjdk.jmh.annotations.BenchmarkMode;
27 import org.openjdk.jmh.annotations.Fork;
28 import org.openjdk.jmh.annotations.Measurement;
29 import org.openjdk.jmh.annotations.Mode;
30 import org.openjdk.jmh.annotations.OutputTimeUnit;
31 import org.openjdk.jmh.annotations.Param;
32 import org.openjdk.jmh.annotations.Scope;
33 import org.openjdk.jmh.annotations.Setup;
34 import org.openjdk.jmh.annotations.State;
35 import org.openjdk.jmh.annotations.Warmup;
36 import org.openjdk.jmh.infra.Blackhole;
37 import java.util.concurrent.TimeUnit;
38
39
40
41
42
43 @BenchmarkMode(Mode.AverageTime)
44 @OutputTimeUnit(TimeUnit.NANOSECONDS)
45 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
46 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
47 @State(Scope.Benchmark)
48 @Fork(value = 1, jvmArgs = { "-server", "-Xms128M", "-Xmx128M" })
49 public class UnitSphereSamplerBenchmark {
50
51 private static final String BASELINE = "Baseline";
52
53 private static final String NON_ARRAY = "NonArray";
54
55 private static final String ARRAY = "Array";
56
57 private static final String UNKNOWN_SAMPLER = "Unknown sampler type: ";
58
59
60
61
62
63
64 @State(Scope.Benchmark)
65 public abstract static class SamplerData {
66
67 private ObjectSampler<double[]> sampler;
68
69
70 @Param({"100"})
71 private int size;
72
73
74
75
76
77
78 public int getSize() {
79 return size;
80 }
81
82
83
84
85
86
87 public ObjectSampler<double[]> getSampler() {
88 return sampler;
89 }
90
91
92
93
94 @Setup
95 public void setup() {
96
97 final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
98 sampler = createSampler(rng);
99 }
100
101
102
103
104
105
106
107 protected abstract ObjectSampler<double[]> createSampler(UniformRandomProvider rng);
108 }
109
110
111
112
113 @State(Scope.Benchmark)
114 public static class Sampler1D extends SamplerData {
115
116 private static final String SIGNED_DOUBLE = "signedDouble";
117
118 private static final String MASKED_INT = "maskedInt";
119
120 private static final String MASKED_LONG = "maskedLong";
121
122 private static final String BOOLEAN = "boolean";
123
124 private static final long ONE = Double.doubleToRawLongBits(1.0);
125
126 private static final long SIGN_BIT = 1L << 31;
127
128
129 @Param({BASELINE, SIGNED_DOUBLE, MASKED_INT, MASKED_LONG, BOOLEAN, ARRAY})
130 private String type;
131
132
133 @Override
134 protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
135 if (BASELINE.equals(type)) {
136 return () -> {
137 return new double[] {1.0};
138 };
139 } else if (SIGNED_DOUBLE.equals(type)) {
140 return () -> {
141
142
143 return new double[] {1.0 - ((rng.nextInt() >>> 30) & 0x2)};
144 };
145 } else if (MASKED_INT.equals(type)) {
146 return () -> {
147
148 return new double[] {Double.longBitsToDouble(ONE | ((rng.nextInt() & SIGN_BIT) << 32))};
149 };
150 } else if (MASKED_LONG.equals(type)) {
151 return () -> {
152
153 return new double[] {Double.longBitsToDouble(ONE | (rng.nextLong() & Long.MIN_VALUE))};
154 };
155 } else if (BOOLEAN.equals(type)) {
156 return () -> {
157 return new double[] {rng.nextBoolean() ? -1.0 : 1.0};
158 };
159 } else if (ARRAY.equals(type)) {
160 return new ArrayBasedUnitSphereSampler(1, rng);
161 }
162 throw new IllegalStateException(UNKNOWN_SAMPLER + type);
163 }
164 }
165
166
167
168
169 @State(Scope.Benchmark)
170 public static class Sampler2D extends SamplerData {
171
172 @Param({BASELINE, ARRAY, NON_ARRAY})
173 private String type;
174
175
176 @Override
177 protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
178 if (BASELINE.equals(type)) {
179 return () -> new double[] {1.0, 0.0};
180 } else if (ARRAY.equals(type)) {
181 return new ArrayBasedUnitSphereSampler(2, rng);
182 } else if (NON_ARRAY.equals(type)) {
183 return new UnitSphereSampler2D(rng);
184 }
185 throw new IllegalStateException(UNKNOWN_SAMPLER + type);
186 }
187
188
189
190
191 private static class UnitSphereSampler2D implements ObjectSampler<double[]> {
192
193 private final NormalizedGaussianSampler sampler;
194
195
196
197
198 UnitSphereSampler2D(UniformRandomProvider rng) {
199 sampler = ZigguratSampler.NormalizedGaussian.of(rng);
200 }
201
202 @Override
203 public double[] sample() {
204 final double x = sampler.sample();
205 final double y = sampler.sample();
206 final double sum = x * x + y * y;
207
208 if (sum == 0) {
209
210 return sample();
211 }
212
213 final double f = 1.0 / Math.sqrt(sum);
214 return new double[] {x * f, y * f};
215 }
216 }
217 }
218
219
220
221
222 @State(Scope.Benchmark)
223 public static class Sampler3D extends SamplerData {
224
225 @Param({BASELINE, ARRAY, NON_ARRAY})
226 private String type;
227
228
229 @Override
230 protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
231 if (BASELINE.equals(type)) {
232 return () -> new double[] {1.0, 0.0, 0.0};
233 } else if (ARRAY.equals(type)) {
234 return new ArrayBasedUnitSphereSampler(3, rng);
235 } else if (NON_ARRAY.equals(type)) {
236 return new UnitSphereSampler3D(rng);
237 }
238 throw new IllegalStateException(UNKNOWN_SAMPLER + type);
239 }
240
241
242
243
244 private static class UnitSphereSampler3D implements ObjectSampler<double[]> {
245
246 private final NormalizedGaussianSampler sampler;
247
248
249
250
251 UnitSphereSampler3D(UniformRandomProvider rng) {
252 sampler = ZigguratSampler.NormalizedGaussian.of(rng);
253 }
254
255 @Override
256 public double[] sample() {
257 final double x = sampler.sample();
258 final double y = sampler.sample();
259 final double z = sampler.sample();
260 final double sum = x * x + y * y + z * z;
261
262 if (sum == 0) {
263
264 return sample();
265 }
266
267 final double f = 1.0 / Math.sqrt(sum);
268 return new double[] {x * f, y * f, z * f};
269 }
270 }
271 }
272
273
274
275
276 @State(Scope.Benchmark)
277 public static class Sampler4D extends SamplerData {
278
279 @Param({BASELINE, ARRAY, NON_ARRAY})
280 private String type;
281
282
283 @Override
284 protected ObjectSampler<double[]> createSampler(final UniformRandomProvider rng) {
285 if (BASELINE.equals(type)) {
286 return () -> new double[] {1.0, 0.0, 0.0, 0.0};
287 } else if (ARRAY.equals(type)) {
288 return new ArrayBasedUnitSphereSampler(4, rng);
289 } else if (NON_ARRAY.equals(type)) {
290 return new UnitSphereSampler4D(rng);
291 }
292 throw new IllegalStateException(UNKNOWN_SAMPLER + type);
293 }
294
295
296
297
298 private static class UnitSphereSampler4D implements ObjectSampler<double[]> {
299
300 private final NormalizedGaussianSampler sampler;
301
302
303
304
305 UnitSphereSampler4D(UniformRandomProvider rng) {
306 sampler = ZigguratSampler.NormalizedGaussian.of(rng);
307 }
308
309 @Override
310 public double[] sample() {
311 final double x = sampler.sample();
312 final double y = sampler.sample();
313 final double z = sampler.sample();
314 final double a = sampler.sample();
315 final double sum = x * x + y * y + z * z + a * a;
316
317 if (sum == 0) {
318
319 return sample();
320 }
321
322 final double f = 1.0 / Math.sqrt(sum);
323 return new double[] {x * f, y * f, z * f, a * f};
324 }
325 }
326 }
327
328
329
330
331 private static class ArrayBasedUnitSphereSampler implements ObjectSampler<double[]> {
332
333 private final int dimension;
334
335 private final NormalizedGaussianSampler sampler;
336
337
338
339
340
341 ArrayBasedUnitSphereSampler(int dimension, UniformRandomProvider rng) {
342 this.dimension = dimension;
343 sampler = ZigguratSampler.NormalizedGaussian.of(rng);
344 }
345
346 @Override
347 public double[] sample() {
348 final double[] v = new double[dimension];
349
350
351
352 double sum = 0;
353 for (int i = 0; i < dimension; i++) {
354 final double x = sampler.sample();
355 v[i] = x;
356 sum += x * x;
357 }
358
359 if (sum == 0) {
360
361 return sample();
362 }
363
364 final double f = 1 / Math.sqrt(sum);
365 for (int i = 0; i < dimension; i++) {
366 v[i] *= f;
367 }
368
369 return v;
370 }
371 }
372
373
374
375
376
377
378
379 private static void runSampler(Blackhole bh, SamplerData data) {
380 final ObjectSampler<double[]> sampler = data.getSampler();
381 for (int i = data.getSize() - 1; i >= 0; i--) {
382 bh.consume(sampler.sample());
383 }
384 }
385
386
387
388
389
390
391
392 @Benchmark
393 public void create1D(Blackhole bh, Sampler1D data) {
394 runSampler(bh, data);
395 }
396
397
398
399
400
401
402
403 @Benchmark
404 public void create2D(Blackhole bh, Sampler2D data) {
405 runSampler(bh, data);
406 }
407
408
409
410
411
412
413
414 @Benchmark
415 public void create3D(Blackhole bh, Sampler3D data) {
416 runSampler(bh, data);
417 }
418
419
420
421
422
423
424
425 @Benchmark
426 public void create4D(Blackhole bh, Sampler4D data) {
427 runSampler(bh, data);
428 }
429 }