View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.core.util;
18  
19  import java.util.Objects;
20  import java.util.Spliterator;
21  import java.util.function.Consumer;
22  import java.util.stream.Stream;
23  import java.util.stream.StreamSupport;
24  import org.apache.commons.rng.SplittableUniformRandomProvider;
25  import org.apache.commons.rng.UniformRandomProvider;
26  
27  /**
28   * Utility for creating streams using a source of randomness.
29   *
30   * @since 1.5
31   */
32  public final class RandomStreams {
33      /** The number of bits of each random character in the seed.
34       * The generation algorithm will work if this is in the range [2, 30]. */
35      private static final int SEED_CHAR_BITS = 4;
36  
37      /**
38       * A factory for creating objects using a seed and a using a source of randomness.
39       *
40       * @param <T> the object type
41       * @since 1.5
42       */
43      public interface SeededObjectFactory<T> {
44          /**
45           * Creates the object.
46           *
47           * @param seed Seed used to initialise the instance.
48           * @param source Source of randomness used to initialise the instance.
49           * @return the object
50           */
51          T create(long seed, UniformRandomProvider source);
52      }
53  
54      /**
55       * Class contains only static methods.
56       */
57      private RandomStreams() {}
58  
59      /**
60       * Returns a stream producing the given {@code streamSize} number of new objects
61       * generated using the supplied {@code source} of randomness and object {@code factory}.
62       *
63       * <p>A {@code long} seed is provided for each object instance using the stream position
64       * and random bits created from the supplied {@code source}.
65       *
66       * <p>The stream supports parallel execution by splitting the provided {@code source}
67       * of randomness. Consequently objects in the same position in the stream created from
68       * a sequential stream may be created from a different source of randomness than a parallel
69       * stream; it is not expected that parallel execution will create the same final
70       * collection of objects.
71       *
72       * @param <T> the object type
73       * @param streamSize Number of objects to generate.
74       * @param source A source of randomness used to initialise the new instances; this may
75       * be split to provide a source of randomness across a parallel stream.
76       * @param factory Factory to create new instances.
77       * @return a stream of objects; the stream is limited to the given {@code streamSize}.
78       * @throws IllegalArgumentException if {@code streamSize} is negative.
79       * @throws NullPointerException if {@code source} or {@code factory} is null.
80       */
81      public static <T> Stream<T> generateWithSeed(long streamSize,
82                                                   SplittableUniformRandomProvider source,
83                                                   SeededObjectFactory<T> factory) {
84          if (streamSize < 0) {
85              throw new IllegalArgumentException("Invalid stream size: " + streamSize);
86          }
87          Objects.requireNonNull(source, "source");
88          Objects.requireNonNull(factory, "factory");
89          final long seed = createSeed(source);
90          return StreamSupport
91              .stream(new SeededObjectSpliterator<>(0, streamSize, source, factory, seed), false);
92      }
93  
94      /**
95       * Creates a seed to prepend to a counter. The seed is created to satisfy the following
96       * requirements:
97       * <ul>
98       * <li>The least significant bit is set
99       * <li>The seed is composed of characters from an n-bit alphabet
100      * <li>The character used in the least significant bits is unique
101      * <li>The other characters are sampled uniformly from the remaining (n-1) characters
102      * </ul>
103      *
104      * <p>The composed seed is created using {@code ((seed << shift) | count)}
105      * where the shift is applied to ensure non-overlap of the shifted seed and
106      * the count. This is achieved by ensuring the lowest 1-bit of the seed is
107      * above the highest 1-bit of the count. The shift is a multiple of n to ensure
108      * the character used in the least significant bits aligns with higher characters
109      * after a shift. As higher characters exclude the least significant character
110      * no shifted seed can duplicate previously observed composed seeds. This holds
111      * until the least significant character itself is shifted out of the composed seed.
112      *
113      * <p>The seed generation algorithm starts with a random series of bits with the lowest bit
114      * set. Any occurrences of the least significant character in the remaining characters are
115      * replaced using {@link UniformRandomProvider#nextInt()}.
116      *
117      * <p>The remaining characters will be rejected at a rate of 2<sup>-n</sup>. The
118      * character size is a compromise between a low rejection rate and the highest supported
119      * count that may receive a prepended seed.
120      *
121      * <p>The JDK's {@code java.util.random} package uses 4-bits for the character size when
122      * creating a stream of SplittableGenerator. This achieves a rejection rate
123      * of {@code 1/16}. Using this size will require 1 call to generate a {@code long} and
124      * on average 1 call to {@code nextInt(15)}. The maximum supported stream size with a unique
125      * seed per object is 2<sup>60</sup>. The algorithm here also uses a character size of 4-bits;
126      * this simplifies the implementation as there are exactly 16 characters. The algorithm is a
127      * different implementation to the JDK and creates an output seed with similar properties.
128      *
129      * @param rng Source of randomness.
130      * @return the seed
131      */
132     static long createSeed(UniformRandomProvider rng) {
133         // Initial random bits. Lowest bit must be set.
134         long bits = rng.nextLong() | 1;
135         // Mask to extract characters.
136         // Can be used to sample from (n-1) n-bit characters.
137         final long n = (1L << SEED_CHAR_BITS) - 1;
138 
139         // Extract the unique character.
140         final long unique = bits & n;
141 
142         // Check the rest of the characters do not match the unique character.
143         // This loop extracts the remaining characters and replaces if required.
144         // This will work if the characters do not evenly divide into 64 as we iterate
145         // over the count of remaining bits. The original order is maintained so that
146         // if the bits already satisfy the requirements they are unchanged.
147         for (int i = SEED_CHAR_BITS; i < Long.SIZE; i += SEED_CHAR_BITS) {
148             // Next character
149             long c = (bits >>> i) & n;
150             if (c == unique) {
151                 // Branch frequency of 2^-bits.
152                 // This code is deliberately branchless.
153                 // Avoid nextInt(n) using: c = floor(n * ([0, 2^32) / 2^32))
154                 // Rejection rate for non-uniformity will be negligible: 2^32 % 15 == 1
155                 // so any rejection algorithm only has to exclude 1 value from nextInt().
156                 c = (n * Integer.toUnsignedLong(rng.nextInt())) >>> Integer.SIZE;
157                 // Ensure the sample is uniform in [0, n] excluding the unique character
158                 c = (unique + c + 1) & n;
159                 // Replace by masking out the current character and bitwise add the new one
160                 bits = (bits & ~(n << i)) | (c << i);
161             }
162         }
163         return bits;
164     }
165 
166     /**
167      * Spliterator for streams of a given object type that can be created from a seed
168      * and source of randomness. The source of randomness is splittable allowing parallel
169      * stream support.
170      *
171      * <p>The seed is mixed with the stream position to ensure each object is created using
172      * a unique seed value. As the position increases the seed is left shifted until there
173      * is no bit overlap between the seed and the position, i.e the right-most 1-bit of the seed
174      * is larger than the left-most 1-bit of the position.
175      *s
176      * @param <T> the object type
177      */
178     private static final class SeededObjectSpliterator<T>
179             implements Spliterator<T> {
180         /** Message when the consumer action is null. */
181         private static final String NULL_ACTION = "action must not be null";
182 
183         /** The current position in the range. */
184         private long position;
185         /** The upper limit of the range. */
186         private final long end;
187         /** Seed used to initialise the new instances. The least significant 1-bit of
188          * the seed must be above the most significant bit of the position. This is maintained
189          * by left shift when the position is updated. */
190         private long seed;
191         /** Source of randomness used to initialise the new instances. */
192         private final SplittableUniformRandomProvider source;
193         /** Factory to create new instances. */
194         private final SeededObjectFactory<T> factory;
195 
196         /**
197          * @param start Start position of the stream (inclusive).
198          * @param end Upper limit of the stream (exclusive).
199          * @param source Source of randomness used to initialise the new instances.
200          * @param factory Factory to create new instances.
201          * @param seed Seed used to initialise the instances. The least significant 1-bit of
202          * the seed must be above the most significant bit of the {@code start} position.
203          */
204         SeededObjectSpliterator(long start, long end,
205                                 SplittableUniformRandomProvider source,
206                                 SeededObjectFactory<T> factory,
207                                 long seed) {
208             position = start;
209             this.end = end;
210             this.seed = seed;
211             this.source = source;
212             this.factory = factory;
213         }
214 
215         @Override
216         public long estimateSize() {
217             return end - position;
218         }
219 
220         @Override
221         public int characteristics() {
222             return Spliterator.SIZED | Spliterator.SUBSIZED | Spliterator.IMMUTABLE;
223         }
224 
225         @Override
226         public Spliterator<T> trySplit() {
227             final long start = position;
228             final long middle = (start + end) >>> 1;
229             if (middle <= start) {
230                 return null;
231             }
232             // The child spliterator can use the same seed as the position does not overlap
233             final SeededObjectSpliterator<T> s =
234                 new SeededObjectSpliterator<>(start, middle, source.split(), factory, seed);
235             // Since the position has increased ensure the seed does not overlap
236             position = middle;
237             while (seed != 0 && Long.compareUnsigned(Long.lowestOneBit(seed), middle) <= 0) {
238                 seed <<= SEED_CHAR_BITS;
239             }
240             return s;
241         }
242 
243         @Override
244         public boolean tryAdvance(Consumer<? super T> action) {
245             Objects.requireNonNull(action, NULL_ACTION);
246             final long pos = position;
247             if (pos < end) {
248                 // Advance before exceptions from the action are relayed to the caller
249                 position = pos + 1;
250                 action.accept(factory.create(seed | pos, source));
251                 // If the position overlaps the seed, shift it by 1 character
252                 if ((position & seed) != 0) {
253                     seed <<= SEED_CHAR_BITS;
254                 }
255                 return true;
256             }
257             return false;
258         }
259 
260         @Override
261         public void forEachRemaining(Consumer<? super T> action) {
262             Objects.requireNonNull(action, NULL_ACTION);
263             long pos = position;
264             final long last = end;
265             if (pos < last) {
266                 // Ensure forEachRemaining is called only once
267                 position = last;
268                 final SplittableUniformRandomProvider s = source;
269                 final SeededObjectFactory<T> f = factory;
270                 do {
271                     action.accept(f.create(seed | pos, s));
272                     pos++;
273                     // If the position overlaps the seed, shift it by 1 character
274                     if ((pos & seed) != 0) {
275                         seed <<= SEED_CHAR_BITS;
276                     }
277                 } while (pos < last);
278             }
279         }
280     }
281 }