1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package org.apache.commons.rng.core.util; 18 19 import java.util.Objects; 20 import java.util.Spliterator; 21 import java.util.function.Consumer; 22 import java.util.stream.Stream; 23 import java.util.stream.StreamSupport; 24 import org.apache.commons.rng.SplittableUniformRandomProvider; 25 import org.apache.commons.rng.UniformRandomProvider; 26 27 /** 28 * Utility for creating streams using a source of randomness. 29 * 30 * @since 1.5 31 */ 32 public final class RandomStreams { 33 /** The number of bits of each random character in the seed. 34 * The generation algorithm will work if this is in the range [2, 30]. */ 35 private static final int SEED_CHAR_BITS = 4; 36 37 /** 38 * A factory for creating objects using a seed and a using a source of randomness. 39 * 40 * @param <T> the object type 41 * @since 1.5 42 */ 43 public interface SeededObjectFactory<T> { 44 /** 45 * Creates the object. 46 * 47 * @param seed Seed used to initialise the instance. 48 * @param source Source of randomness used to initialise the instance. 49 * @return the object 50 */ 51 T create(long seed, UniformRandomProvider source); 52 } 53 54 /** 55 * Class contains only static methods. 56 */ 57 private RandomStreams() {} 58 59 /** 60 * Returns a stream producing the given {@code streamSize} number of new objects 61 * generated using the supplied {@code source} of randomness and object {@code factory}. 62 * 63 * <p>A {@code long} seed is provided for each object instance using the stream position 64 * and random bits created from the supplied {@code source}. 65 * 66 * <p>The stream supports parallel execution by splitting the provided {@code source} 67 * of randomness. Consequently objects in the same position in the stream created from 68 * a sequential stream may be created from a different source of randomness than a parallel 69 * stream; it is not expected that parallel execution will create the same final 70 * collection of objects. 71 * 72 * @param <T> the object type 73 * @param streamSize Number of objects to generate. 74 * @param source A source of randomness used to initialise the new instances; this may 75 * be split to provide a source of randomness across a parallel stream. 76 * @param factory Factory to create new instances. 77 * @return a stream of objects; the stream is limited to the given {@code streamSize}. 78 * @throws IllegalArgumentException if {@code streamSize} is negative. 79 * @throws NullPointerException if {@code source} or {@code factory} is null. 80 */ 81 public static <T> Stream<T> generateWithSeed(long streamSize, 82 SplittableUniformRandomProvider source, 83 SeededObjectFactory<T> factory) { 84 if (streamSize < 0) { 85 throw new IllegalArgumentException("Invalid stream size: " + streamSize); 86 } 87 Objects.requireNonNull(source, "source"); 88 Objects.requireNonNull(factory, "factory"); 89 final long seed = createSeed(source); 90 return StreamSupport 91 .stream(new SeededObjectSpliterator<>(0, streamSize, source, factory, seed), false); 92 } 93 94 /** 95 * Creates a seed to prepend to a counter. The seed is created to satisfy the following 96 * requirements: 97 * <ul> 98 * <li>The least significant bit is set 99 * <li>The seed is composed of characters from an n-bit alphabet 100 * <li>The character used in the least significant bits is unique 101 * <li>The other characters are sampled uniformly from the remaining (n-1) characters 102 * </ul> 103 * 104 * <p>The composed seed is created using {@code ((seed << shift) | count)} 105 * where the shift is applied to ensure non-overlap of the shifted seed and 106 * the count. This is achieved by ensuring the lowest 1-bit of the seed is 107 * above the highest 1-bit of the count. The shift is a multiple of n to ensure 108 * the character used in the least significant bits aligns with higher characters 109 * after a shift. As higher characters exclude the least significant character 110 * no shifted seed can duplicate previously observed composed seeds. This holds 111 * until the least significant character itself is shifted out of the composed seed. 112 * 113 * <p>The seed generation algorithm starts with a random series of bits with the lowest bit 114 * set. Any occurrences of the least significant character in the remaining characters are 115 * replaced using {@link UniformRandomProvider#nextInt()}. 116 * 117 * <p>The remaining characters will be rejected at a rate of 2<sup>-n</sup>. The 118 * character size is a compromise between a low rejection rate and the highest supported 119 * count that may receive a prepended seed. 120 * 121 * <p>The JDK's {@code java.util.random} package uses 4-bits for the character size when 122 * creating a stream of SplittableGenerator. This achieves a rejection rate 123 * of {@code 1/16}. Using this size will require 1 call to generate a {@code long} and 124 * on average 1 call to {@code nextInt(15)}. The maximum supported stream size with a unique 125 * seed per object is 2<sup>60</sup>. The algorithm here also uses a character size of 4-bits; 126 * this simplifies the implementation as there are exactly 16 characters. The algorithm is a 127 * different implementation to the JDK and creates an output seed with similar properties. 128 * 129 * @param rng Source of randomness. 130 * @return the seed 131 */ 132 static long createSeed(UniformRandomProvider rng) { 133 // Initial random bits. Lowest bit must be set. 134 long bits = rng.nextLong() | 1; 135 // Mask to extract characters. 136 // Can be used to sample from (n-1) n-bit characters. 137 final long n = (1L << SEED_CHAR_BITS) - 1; 138 139 // Extract the unique character. 140 final long unique = bits & n; 141 142 // Check the rest of the characters do not match the unique character. 143 // This loop extracts the remaining characters and replaces if required. 144 // This will work if the characters do not evenly divide into 64 as we iterate 145 // over the count of remaining bits. The original order is maintained so that 146 // if the bits already satisfy the requirements they are unchanged. 147 for (int i = SEED_CHAR_BITS; i < Long.SIZE; i += SEED_CHAR_BITS) { 148 // Next character 149 long c = (bits >>> i) & n; 150 if (c == unique) { 151 // Branch frequency of 2^-bits. 152 // This code is deliberately branchless. 153 // Avoid nextInt(n) using: c = floor(n * ([0, 2^32) / 2^32)) 154 // Rejection rate for non-uniformity will be negligible: 2^32 % 15 == 1 155 // so any rejection algorithm only has to exclude 1 value from nextInt(). 156 c = (n * Integer.toUnsignedLong(rng.nextInt())) >>> Integer.SIZE; 157 // Ensure the sample is uniform in [0, n] excluding the unique character 158 c = (unique + c + 1) & n; 159 // Replace by masking out the current character and bitwise add the new one 160 bits = (bits & ~(n << i)) | (c << i); 161 } 162 } 163 return bits; 164 } 165 166 /** 167 * Spliterator for streams of a given object type that can be created from a seed 168 * and source of randomness. The source of randomness is splittable allowing parallel 169 * stream support. 170 * 171 * <p>The seed is mixed with the stream position to ensure each object is created using 172 * a unique seed value. As the position increases the seed is left shifted until there 173 * is no bit overlap between the seed and the position, i.e the right-most 1-bit of the seed 174 * is larger than the left-most 1-bit of the position. 175 *s 176 * @param <T> the object type 177 */ 178 private static final class SeededObjectSpliterator<T> 179 implements Spliterator<T> { 180 /** Message when the consumer action is null. */ 181 private static final String NULL_ACTION = "action must not be null"; 182 183 /** The current position in the range. */ 184 private long position; 185 /** The upper limit of the range. */ 186 private final long end; 187 /** Seed used to initialise the new instances. The least significant 1-bit of 188 * the seed must be above the most significant bit of the position. This is maintained 189 * by left shift when the position is updated. */ 190 private long seed; 191 /** Source of randomness used to initialise the new instances. */ 192 private final SplittableUniformRandomProvider source; 193 /** Factory to create new instances. */ 194 private final SeededObjectFactory<T> factory; 195 196 /** 197 * @param start Start position of the stream (inclusive). 198 * @param end Upper limit of the stream (exclusive). 199 * @param source Source of randomness used to initialise the new instances. 200 * @param factory Factory to create new instances. 201 * @param seed Seed used to initialise the instances. The least significant 1-bit of 202 * the seed must be above the most significant bit of the {@code start} position. 203 */ 204 SeededObjectSpliterator(long start, long end, 205 SplittableUniformRandomProvider source, 206 SeededObjectFactory<T> factory, 207 long seed) { 208 position = start; 209 this.end = end; 210 this.seed = seed; 211 this.source = source; 212 this.factory = factory; 213 } 214 215 @Override 216 public long estimateSize() { 217 return end - position; 218 } 219 220 @Override 221 public int characteristics() { 222 return Spliterator.SIZED | Spliterator.SUBSIZED | Spliterator.IMMUTABLE; 223 } 224 225 @Override 226 public Spliterator<T> trySplit() { 227 final long start = position; 228 final long middle = (start + end) >>> 1; 229 if (middle <= start) { 230 return null; 231 } 232 // The child spliterator can use the same seed as the position does not overlap 233 final SeededObjectSpliterator<T> s = 234 new SeededObjectSpliterator<>(start, middle, source.split(), factory, seed); 235 // Since the position has increased ensure the seed does not overlap 236 position = middle; 237 while (seed != 0 && Long.compareUnsigned(Long.lowestOneBit(seed), middle) <= 0) { 238 seed <<= SEED_CHAR_BITS; 239 } 240 return s; 241 } 242 243 @Override 244 public boolean tryAdvance(Consumer<? super T> action) { 245 Objects.requireNonNull(action, NULL_ACTION); 246 final long pos = position; 247 if (pos < end) { 248 // Advance before exceptions from the action are relayed to the caller 249 position = pos + 1; 250 action.accept(factory.create(seed | pos, source)); 251 // If the position overlaps the seed, shift it by 1 character 252 if ((position & seed) != 0) { 253 seed <<= SEED_CHAR_BITS; 254 } 255 return true; 256 } 257 return false; 258 } 259 260 @Override 261 public void forEachRemaining(Consumer<? super T> action) { 262 Objects.requireNonNull(action, NULL_ACTION); 263 long pos = position; 264 final long last = end; 265 if (pos < last) { 266 // Ensure forEachRemaining is called only once 267 position = last; 268 final SplittableUniformRandomProvider s = source; 269 final SeededObjectFactory<T> f = factory; 270 do { 271 action.accept(f.create(seed | pos, s)); 272 pos++; 273 // If the position overlaps the seed, shift it by 1 character 274 if ((pos & seed) != 0) { 275 seed <<= SEED_CHAR_BITS; 276 } 277 } while (pos < last); 278 } 279 } 280 } 281 }