RandomStreams.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.rng.core.util;

  18. import java.util.Objects;
  19. import java.util.Spliterator;
  20. import java.util.function.Consumer;
  21. import java.util.stream.Stream;
  22. import java.util.stream.StreamSupport;
  23. import org.apache.commons.rng.SplittableUniformRandomProvider;
  24. import org.apache.commons.rng.UniformRandomProvider;

  25. /**
  26.  * Utility for creating streams using a source of randomness.
  27.  *
  28.  * @since 1.5
  29.  */
  30. public final class RandomStreams {
  31.     /** The number of bits of each random character in the seed.
  32.      * The generation algorithm will work if this is in the range [2, 30]. */
  33.     private static final int SEED_CHAR_BITS = 4;

  34.     /**
  35.      * A factory for creating objects using a seed and a using a source of randomness.
  36.      *
  37.      * @param <T> the object type
  38.      * @since 1.5
  39.      */
  40.     public interface SeededObjectFactory<T> {
  41.         /**
  42.          * Creates the object.
  43.          *
  44.          * @param seed Seed used to initialise the instance.
  45.          * @param source Source of randomness used to initialise the instance.
  46.          * @return the object
  47.          */
  48.         T create(long seed, UniformRandomProvider source);
  49.     }

  50.     /**
  51.      * Class contains only static methods.
  52.      */
  53.     private RandomStreams() {}

  54.     /**
  55.      * Returns a stream producing the given {@code streamSize} number of new objects
  56.      * generated using the supplied {@code source} of randomness and object {@code factory}.
  57.      *
  58.      * <p>A {@code long} seed is provided for each object instance using the stream position
  59.      * and random bits created from the supplied {@code source}.
  60.      *
  61.      * <p>The stream supports parallel execution by splitting the provided {@code source}
  62.      * of randomness. Consequently objects in the same position in the stream created from
  63.      * a sequential stream may be created from a different source of randomness than a parallel
  64.      * stream; it is not expected that parallel execution will create the same final
  65.      * collection of objects.
  66.      *
  67.      * @param <T> the object type
  68.      * @param streamSize Number of objects to generate.
  69.      * @param source A source of randomness used to initialise the new instances; this may
  70.      * be split to provide a source of randomness across a parallel stream.
  71.      * @param factory Factory to create new instances.
  72.      * @return a stream of objects; the stream is limited to the given {@code streamSize}.
  73.      * @throws IllegalArgumentException if {@code streamSize} is negative.
  74.      * @throws NullPointerException if {@code source} or {@code factory} is null.
  75.      */
  76.     public static <T> Stream<T> generateWithSeed(long streamSize,
  77.                                                  SplittableUniformRandomProvider source,
  78.                                                  SeededObjectFactory<T> factory) {
  79.         if (streamSize < 0) {
  80.             throw new IllegalArgumentException("Invalid stream size: " + streamSize);
  81.         }
  82.         Objects.requireNonNull(source, "source");
  83.         Objects.requireNonNull(factory, "factory");
  84.         final long seed = createSeed(source);
  85.         return StreamSupport
  86.             .stream(new SeededObjectSpliterator<>(0, streamSize, source, factory, seed), false);
  87.     }

  88.     /**
  89.      * Creates a seed to prepend to a counter. The seed is created to satisfy the following
  90.      * requirements:
  91.      * <ul>
  92.      * <li>The least significant bit is set
  93.      * <li>The seed is composed of characters from an n-bit alphabet
  94.      * <li>The character used in the least significant bits is unique
  95.      * <li>The other characters are sampled uniformly from the remaining (n-1) characters
  96.      * </ul>
  97.      *
  98.      * <p>The composed seed is created using {@code ((seed << shift) | count)}
  99.      * where the shift is applied to ensure non-overlap of the shifted seed and
  100.      * the count. This is achieved by ensuring the lowest 1-bit of the seed is
  101.      * above the highest 1-bit of the count. The shift is a multiple of n to ensure
  102.      * the character used in the least significant bits aligns with higher characters
  103.      * after a shift. As higher characters exclude the least significant character
  104.      * no shifted seed can duplicate previously observed composed seeds. This holds
  105.      * until the least significant character itself is shifted out of the composed seed.
  106.      *
  107.      * <p>The seed generation algorithm starts with a random series of bits with the lowest bit
  108.      * set. Any occurrences of the least significant character in the remaining characters are
  109.      * replaced using {@link UniformRandomProvider#nextInt()}.
  110.      *
  111.      * <p>The remaining characters will be rejected at a rate of 2<sup>-n</sup>. The
  112.      * character size is a compromise between a low rejection rate and the highest supported
  113.      * count that may receive a prepended seed.
  114.      *
  115.      * <p>The JDK's {@code java.util.random} package uses 4-bits for the character size when
  116.      * creating a stream of SplittableGenerator. This achieves a rejection rate
  117.      * of {@code 1/16}. Using this size will require 1 call to generate a {@code long} and
  118.      * on average 1 call to {@code nextInt(15)}. The maximum supported stream size with a unique
  119.      * seed per object is 2<sup>60</sup>. The algorithm here also uses a character size of 4-bits;
  120.      * this simplifies the implementation as there are exactly 16 characters. The algorithm is a
  121.      * different implementation to the JDK and creates an output seed with similar properties.
  122.      *
  123.      * @param rng Source of randomness.
  124.      * @return the seed
  125.      */
  126.     static long createSeed(UniformRandomProvider rng) {
  127.         // Initial random bits. Lowest bit must be set.
  128.         long bits = rng.nextLong() | 1;
  129.         // Mask to extract characters.
  130.         // Can be used to sample from (n-1) n-bit characters.
  131.         final long n = (1L << SEED_CHAR_BITS) - 1;

  132.         // Extract the unique character.
  133.         final long unique = bits & n;

  134.         // Check the rest of the characters do not match the unique character.
  135.         // This loop extracts the remaining characters and replaces if required.
  136.         // This will work if the characters do not evenly divide into 64 as we iterate
  137.         // over the count of remaining bits. The original order is maintained so that
  138.         // if the bits already satisfy the requirements they are unchanged.
  139.         for (int i = SEED_CHAR_BITS; i < Long.SIZE; i += SEED_CHAR_BITS) {
  140.             // Next character
  141.             long c = (bits >>> i) & n;
  142.             if (c == unique) {
  143.                 // Branch frequency of 2^-bits.
  144.                 // This code is deliberately branchless.
  145.                 // Avoid nextInt(n) using: c = floor(n * ([0, 2^32) / 2^32))
  146.                 // Rejection rate for non-uniformity will be negligible: 2^32 % 15 == 1
  147.                 // so any rejection algorithm only has to exclude 1 value from nextInt().
  148.                 c = (n * Integer.toUnsignedLong(rng.nextInt())) >>> Integer.SIZE;
  149.                 // Ensure the sample is uniform in [0, n] excluding the unique character
  150.                 c = (unique + c + 1) & n;
  151.                 // Replace by masking out the current character and bitwise add the new one
  152.                 bits = (bits & ~(n << i)) | (c << i);
  153.             }
  154.         }
  155.         return bits;
  156.     }

  157.     /**
  158.      * Spliterator for streams of a given object type that can be created from a seed
  159.      * and source of randomness. The source of randomness is splittable allowing parallel
  160.      * stream support.
  161.      *
  162.      * <p>The seed is mixed with the stream position to ensure each object is created using
  163.      * a unique seed value. As the position increases the seed is left shifted until there
  164.      * is no bit overlap between the seed and the position, i.e the right-most 1-bit of the seed
  165.      * is larger than the left-most 1-bit of the position.
  166.      *s
  167.      * @param <T> the object type
  168.      */
  169.     private static final class SeededObjectSpliterator<T>
  170.             implements Spliterator<T> {
  171.         /** Message when the consumer action is null. */
  172.         private static final String NULL_ACTION = "action must not be null";

  173.         /** The current position in the range. */
  174.         private long position;
  175.         /** The upper limit of the range. */
  176.         private final long end;
  177.         /** Seed used to initialise the new instances. The least significant 1-bit of
  178.          * the seed must be above the most significant bit of the position. This is maintained
  179.          * by left shift when the position is updated. */
  180.         private long seed;
  181.         /** Source of randomness used to initialise the new instances. */
  182.         private final SplittableUniformRandomProvider source;
  183.         /** Factory to create new instances. */
  184.         private final SeededObjectFactory<T> factory;

  185.         /**
  186.          * @param start Start position of the stream (inclusive).
  187.          * @param end Upper limit of the stream (exclusive).
  188.          * @param source Source of randomness used to initialise the new instances.
  189.          * @param factory Factory to create new instances.
  190.          * @param seed Seed used to initialise the instances. The least significant 1-bit of
  191.          * the seed must be above the most significant bit of the {@code start} position.
  192.          */
  193.         SeededObjectSpliterator(long start, long end,
  194.                                 SplittableUniformRandomProvider source,
  195.                                 SeededObjectFactory<T> factory,
  196.                                 long seed) {
  197.             position = start;
  198.             this.end = end;
  199.             this.seed = seed;
  200.             this.source = source;
  201.             this.factory = factory;
  202.         }

  203.         @Override
  204.         public long estimateSize() {
  205.             return end - position;
  206.         }

  207.         @Override
  208.         public int characteristics() {
  209.             return SIZED | SUBSIZED | IMMUTABLE;
  210.         }

  211.         @Override
  212.         public Spliterator<T> trySplit() {
  213.             final long start = position;
  214.             final long middle = (start + end) >>> 1;
  215.             if (middle <= start) {
  216.                 return null;
  217.             }
  218.             // The child spliterator can use the same seed as the position does not overlap
  219.             final SeededObjectSpliterator<T> s =
  220.                 new SeededObjectSpliterator<>(start, middle, source.split(), factory, seed);
  221.             // Since the position has increased ensure the seed does not overlap
  222.             position = middle;
  223.             while (seed != 0 && Long.compareUnsigned(Long.lowestOneBit(seed), middle) <= 0) {
  224.                 seed <<= SEED_CHAR_BITS;
  225.             }
  226.             return s;
  227.         }

  228.         @Override
  229.         public boolean tryAdvance(Consumer<? super T> action) {
  230.             Objects.requireNonNull(action, NULL_ACTION);
  231.             final long pos = position;
  232.             if (pos < end) {
  233.                 // Advance before exceptions from the action are relayed to the caller
  234.                 position = pos + 1;
  235.                 action.accept(factory.create(seed | pos, source));
  236.                 // If the position overlaps the seed, shift it by 1 character
  237.                 if ((position & seed) != 0) {
  238.                     seed <<= SEED_CHAR_BITS;
  239.                 }
  240.                 return true;
  241.             }
  242.             return false;
  243.         }

  244.         @Override
  245.         public void forEachRemaining(Consumer<? super T> action) {
  246.             Objects.requireNonNull(action, NULL_ACTION);
  247.             long pos = position;
  248.             final long last = end;
  249.             if (pos < last) {
  250.                 // Ensure forEachRemaining is called only once
  251.                 position = last;
  252.                 final SplittableUniformRandomProvider s = source;
  253.                 final SeededObjectFactory<T> f = factory;
  254.                 do {
  255.                     action.accept(f.create(seed | pos, s));
  256.                     pos++;
  257.                     // If the position overlaps the seed, shift it by 1 character
  258.                     if ((pos & seed) != 0) {
  259.                         seed <<= SEED_CHAR_BITS;
  260.                     }
  261.                 } while (pos < last);
  262.             }
  263.         }
  264.     }
  265. }