1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.util.Locale;
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.apache.commons.rng.core.source32.IntProvider;
22  import org.apache.commons.rng.sampling.RandomAssert;
23  import org.apache.commons.rng.simple.RandomSource;
24  import org.junit.jupiter.api.Assertions;
25  import org.junit.jupiter.api.Test;
26  
27  
28  
29  
30  
31  class DiscreteUniformSamplerTest {
32      
33  
34  
35      @Test
36      void testConstructorThrowsWithLowerAboveUpper() {
37          final int upper = 55;
38          final int lower = upper + 1;
39          final UniformRandomProvider rng = RandomAssert.seededRNG();
40          Assertions.assertThrows(IllegalArgumentException.class,
41              () -> DiscreteUniformSampler.of(rng, lower, upper));
42      }
43  
44      @Test
45      void testSamplesWithRangeOf1() {
46          final int upper = 99;
47          final int lower = upper;
48          final UniformRandomProvider rng = RandomAssert.createRNG();
49          final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
50          for (int i = 0; i < 5; i++) {
51              Assertions.assertEquals(lower, sampler.sample());
52          }
53      }
54  
55      
56  
57  
58  
59      @Test
60      void testSamplesWithFullRange() {
61          final int upper = Integer.MAX_VALUE;
62          final int lower = Integer.MIN_VALUE;
63          final UniformRandomProvider rng1 = RandomAssert.seededRNG();
64          final UniformRandomProvider rng2 = RandomAssert.seededRNG();
65          final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng2, lower, upper);
66          for (int i = 0; i < 10; i++) {
67              Assertions.assertEquals(rng1.nextInt(), sampler.sample());
68          }
69      }
70  
71      
72  
73  
74  
75  
76  
77      @Test
78      void testSamplesWithSmallNonPowerOf2Range() {
79          final int upper = 257;
80          for (final int lower : new int[] {-13, 0, 13}) {
81              final int n = upper - lower + 1;
82              final UniformRandomProvider rng1 = RandomAssert.seededRNG();
83              final UniformRandomProvider rng2 = RandomAssert.seededRNG();
84              final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng2, lower, upper);
85              for (int i = 0; i < 10; i++) {
86                  Assertions.assertEquals(lower + rng1.nextInt(n), sampler.sample());
87              }
88          }
89      }
90  
91      
92  
93  
94  
95      @Test
96      void testSamplesWithPowerOf2Range() {
97          final UniformRandomProvider rngZeroBits = new IntProvider() {
98              @Override
99              public int next() {
100                 
101                 return 0;
102             }
103         };
104         final UniformRandomProvider rngAllBits = new IntProvider() {
105             @Override
106             public int next() {
107                 
108                 return -1;
109             }
110         };
111 
112         final int lower = -3;
113         DiscreteUniformSampler sampler;
114         
115         
116         
117         for (int i = 0; i < 32; i++) {
118             final int range = 1 << i;
119             final int upper = lower + range - 1;
120             sampler = new DiscreteUniformSampler(rngZeroBits, lower, upper);
121             Assertions.assertEquals(lower, sampler.sample(), "Zero bits sample");
122             sampler = new DiscreteUniformSampler(rngAllBits, lower, upper);
123             Assertions.assertEquals(upper, sampler.sample(), "All bits sample");
124         }
125     }
126 
127     
128 
129 
130 
131     @Test
132     void testSamplesWithPowerOf2RangeIsBitShift() {
133         final int lower = 0;
134         SharedStateDiscreteSampler sampler;
135         
136         for (int i = 1; i <= 31; i++) {
137             
138             final int upper = (1 << i) - 1;
139             final int shift = 32 - i;
140             final UniformRandomProvider rng1 = RandomAssert.seededRNG();
141             final UniformRandomProvider rng2 = RandomAssert.seededRNG();
142             sampler = DiscreteUniformSampler.of(rng2, lower, upper);
143             for (int j = 0; j < 10; j++) {
144                 Assertions.assertEquals(rng1.nextInt() >>> shift, sampler.sample());
145             }
146         }
147     }
148 
149     
150 
151 
152 
153     @Test
154     void testSamplesWithLargeNonPowerOf2RangeIsRejectionMethod() {
155         
156         final int upper = Integer.MAX_VALUE / 2 + 1;
157         final int lower = Integer.MIN_VALUE / 2 - 1;
158         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
159         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
160         final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng2, lower, upper);
161         for (int i = 0; i < 10; i++) {
162             
163             long expected;
164             do {
165                 expected = rng1.nextInt();
166             } while (expected < lower || expected > upper);
167             Assertions.assertEquals(expected, sampler.sample());
168         }
169     }
170 
171     @Test
172     void testOffsetSamplesWithNonPowerOf2Range() {
173         assertOffsetSamples(257);
174     }
175 
176     @Test
177     void testOffsetSamplesWithPowerOf2Range() {
178         assertOffsetSamples(256);
179     }
180 
181     @Test
182     void testOffsetSamplesWithRangeOf1() {
183         assertOffsetSamples(1);
184     }
185 
186     private static void assertOffsetSamples(int range) {
187         final Long seed = RandomSource.createLong();
188         final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed);
189         final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed);
190         final UniformRandomProvider rng3 = RandomSource.SPLIT_MIX_64.create(seed);
191 
192         
193         range = range - 1;
194         final int offsetLo = -13;
195         final int offsetHi = 42;
196         final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng1, 0, range);
197         final SharedStateDiscreteSampler samplerLo = DiscreteUniformSampler.of(rng2, offsetLo, offsetLo + range);
198         final SharedStateDiscreteSampler samplerHi = DiscreteUniformSampler.of(rng3, offsetHi, offsetHi + range);
199         for (int i = 0; i < 10; i++) {
200             final int sample1 = sampler.sample();
201             final int sample2 = samplerLo.sample();
202             final int sample3 = samplerHi.sample();
203             Assertions.assertEquals(sample1 + offsetLo, sample2, "Incorrect negative offset sample");
204             Assertions.assertEquals(sample1 + offsetHi, sample3, "Incorrect positive offset sample");
205         }
206     }
207 
208     
209 
210 
211     @Test
212     void testSampleUniformityWithNonPowerOf2Range() {
213         
214         
215         
216         
217         final UniformRandomProvider rng = new IntProvider() {
218             private final int increment = 362437;
219             
220             private final int start = Integer.MIN_VALUE - increment;
221 
222             private int bits = start;
223 
224             @Override
225             public int next() {
226                 
227                 
228                 
229                 int result = bits += increment;
230                 if (result < start) {
231                     return result;
232                 }
233                 throw new IllegalStateException("end of sequence");
234             }
235         };
236 
237         
238         final int n = 37; 
239         final int[] histogram = new int[n];
240 
241         final int lower = 0;
242         final int upper = n - 1;
243 
244         final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
245 
246         try {
247             while (true) {
248                 histogram[sampler.sample()]++;
249             }
250         } catch (IllegalStateException ex) {
251             
252         }
253 
254         
255         int min = histogram[0];
256         int max = histogram[0];
257         for (int value : histogram) {
258             min = Math.min(min, value);
259             max = Math.max(max, value);
260         }
261         Assertions.assertTrue(max - min <= 1, "Not uniform, max = " + max + ", min=" + min);
262     }
263 
264     
265 
266 
267     @Test
268     void testSampleUniformityWithPowerOf2Range() {
269         
270         
271         
272         final UniformRandomProvider rng = new IntProvider() {
273             private int bits = 0;
274 
275             @Override
276             public int next() {
277                 
278                 return Integer.reverse(bits++);
279             }
280         };
281 
282         
283         final int n = 32; 
284         final int[] histogram = new int[n];
285 
286         final int lower = 0;
287         final int upper = n - 1;
288 
289         final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
290 
291         final int expected = 2;
292         for (int i = expected * n; i-- > 0;) {
293             histogram[sampler.sample()]++;
294         }
295 
296         
297         for (int value : histogram) {
298             Assertions.assertEquals(expected, value);
299         }
300     }
301 
302     
303 
304 
305 
306 
307 
308     @Test
309     void testSampleRejectionWithNonPowerOf2Range() {
310         
311         
312         final int[] value = new int[1];
313         final UniformRandomProvider rng = new IntProvider() {
314             @Override
315             public int next() {
316                 return value[0]++;
317             }
318         };
319 
320         
321         
322         final int n = 37;
323         final int lower = 0;
324         final int upper = n - 1;
325 
326         final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
327 
328         final int sample = sampler.sample();
329 
330         Assertions.assertEquals(0, sample, "Sample is incorrect");
331         Assertions.assertEquals(2, value[0], "Sample should be produced from 2nd value");
332     }
333 
334     @Test
335     void testSharedStateSamplerWithSmallRange() {
336         testSharedStateSampler(5, 67);
337     }
338 
339     @Test
340     void testSharedStateSamplerWithLargeRange() {
341         
342         testSharedStateSampler(Integer.MIN_VALUE / 2 - 1, Integer.MAX_VALUE / 2 + 1);
343     }
344 
345     @Test
346     void testSharedStateSamplerWithPowerOf2Range() {
347         testSharedStateSampler(0, 31);
348     }
349 
350     @Test
351     void testSharedStateSamplerWithRangeOf1() {
352         testSharedStateSampler(9, 9);
353     }
354 
355     
356 
357 
358 
359 
360 
361     private static void testSharedStateSampler(int lower, int upper) {
362         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
363         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
364         
365         final SharedStateDiscreteSampler sampler1 =
366             new DiscreteUniformSampler(rng1, lower, upper);
367         final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
368         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
369     }
370 
371     @Test
372     void testToStringWithSmallRange() {
373         assertToString(5, 67);
374     }
375 
376     @Test
377     void testToStringWithLargeRange() {
378         assertToString(-99999999, Integer.MAX_VALUE);
379     }
380 
381     @Test
382     void testToStringWithPowerOf2Range() {
383         
384         assertToString(0, 31);
385     }
386 
387     @Test
388     void testToStringWithRangeOf1() {
389         assertToString(9, 9);
390     }
391 
392     
393 
394 
395 
396 
397 
398 
399     private static void assertToString(int lower, int upper) {
400         final UniformRandomProvider rng = RandomAssert.seededRNG();
401         final DiscreteUniformSampler sampler = new DiscreteUniformSampler(rng, lower, upper);
402         Assertions.assertTrue(sampler.toString().toLowerCase(Locale.US).contains("uniform"));
403     }
404 }