DiscreteUniformSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.rng.sampling.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;

  19. /**
  20.  * Discrete uniform distribution sampler.
  21.  *
  22.  * <p>Sampling uses {@link UniformRandomProvider#nextInt}.</p>
  23.  *
  24.  * <p>When the range is a power of two the number of calls is 1 per sample.
  25.  * Otherwise a rejection algorithm is used to ensure uniformity. In the worst
  26.  * case scenario where the range spans half the range of an {@code int}
  27.  * (2<sup>31</sup> + 1) the expected number of calls is 2 per sample.</p>
  28.  *
  29.  * <p>This sampler can be used as a replacement for {@link UniformRandomProvider#nextInt}
  30.  * with appropriate adjustment of the upper bound to be inclusive and will outperform that
  31.  * method when the range is not a power of two. The advantage is gained by pre-computation
  32.  * of the rejection threshold.</p>
  33.  *
  34.  * <p>The sampling algorithm is described in:</p>
  35.  *
  36.  * <blockquote>
  37.  *  Lemire, D (2019). <i>Fast Random Integer Generation in an Interval.</i>
  38.  *  <strong>ACM Transactions on Modeling and Computer Simulation</strong> 29 (1).
  39.  * </blockquote>
  40.  *
  41.  * <p>The number of {@code int} values required per sample follows a geometric distribution with
  42.  * a probability of success p of {@code 1 - ((2^32 % n) / 2^32)}. This requires on average 1/p random
  43.  * {@code int} values per sample.</p>
  44.  *
  45.  * @see <a href="https://arxiv.org/abs/1805.10941">Fast Random Integer Generation in an Interval</a>
  46.  *
  47.  * @since 1.0
  48.  */
  49. public class DiscreteUniformSampler
  50.     extends SamplerBase
  51.     implements SharedStateDiscreteSampler {

  52.     /** The appropriate uniform sampler for the parameters. */
  53.     private final SharedStateDiscreteSampler delegate;

  54.     /**
  55.      * Base class for a sampler from a discrete uniform distribution. This contains the
  56.      * source of randomness.
  57.      */
  58.     private abstract static class AbstractDiscreteUniformSampler
  59.             implements SharedStateDiscreteSampler {
  60.         /** Underlying source of randomness. */
  61.         protected final UniformRandomProvider rng;

  62.         /**
  63.          * @param rng Generator of uniformly distributed random numbers.
  64.          */
  65.         AbstractDiscreteUniformSampler(UniformRandomProvider rng) {
  66.             this.rng = rng;
  67.         }

  68.         @Override
  69.         public String toString() {
  70.             return "Uniform deviate [" + rng.toString() + "]";
  71.         }
  72.     }

  73.     /**
  74.      * Discrete uniform distribution sampler when the sample value is fixed.
  75.      */
  76.     private static final class FixedDiscreteUniformSampler
  77.             extends AbstractDiscreteUniformSampler {
  78.         /** The value. */
  79.         private final int value;

  80.         /**
  81.          * @param value The value.
  82.          */
  83.         FixedDiscreteUniformSampler(int value) {
  84.             // No requirement for the RNG
  85.             super(null);
  86.             this.value = value;
  87.         }

  88.         @Override
  89.         public int sample() {
  90.             return value;
  91.         }

  92.         @Override
  93.         public String toString() {
  94.             // No RNG to include in the string
  95.             return "Uniform deviate [X=" + value + "]";
  96.         }

  97.         @Override
  98.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  99.             // No requirement for the RNG
  100.             return this;
  101.         }
  102.     }

  103.     /**
  104.      * Discrete uniform distribution sampler when the range is a power of 2 and greater than 1.
  105.      * This sampler assumes the lower bound of the range is 0.
  106.      *
  107.      * <p>Note: This cannot be used when the range is 1 (2^0) as the shift would be 32-bits
  108.      * which is ignored by the shift operator.</p>
  109.      */
  110.     private static final class PowerOf2RangeDiscreteUniformSampler
  111.             extends AbstractDiscreteUniformSampler {
  112.         /** Bit shift to apply to the integer sample. */
  113.         private final int shift;

  114.         /**
  115.          * @param rng Generator of uniformly distributed random numbers.
  116.          * @param range Maximum range of the sample (exclusive).
  117.          * Must be a power of 2 greater than 2^0.
  118.          */
  119.         PowerOf2RangeDiscreteUniformSampler(UniformRandomProvider rng,
  120.                                             int range) {
  121.             super(rng);
  122.             this.shift = Integer.numberOfLeadingZeros(range) + 1;
  123.         }

  124.         /**
  125.          * @param rng Generator of uniformly distributed random numbers.
  126.          * @param source Source to copy.
  127.          */
  128.         PowerOf2RangeDiscreteUniformSampler(UniformRandomProvider rng,
  129.                                             PowerOf2RangeDiscreteUniformSampler source) {
  130.             super(rng);
  131.             this.shift = source.shift;
  132.         }

  133.         @Override
  134.         public int sample() {
  135.             // Use a bit shift to favour the most significant bits.
  136.             // Note: The result would be the same as the rejection method used in the
  137.             // small range sampler when there is no rejection threshold.
  138.             return rng.nextInt() >>> shift;
  139.         }

  140.         @Override
  141.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  142.             return new PowerOf2RangeDiscreteUniformSampler(rng, this);
  143.         }
  144.     }

  145.     /**
  146.      * Discrete uniform distribution sampler when the range is small
  147.      * enough to fit in a positive integer.
  148.      * This sampler assumes the lower bound of the range is 0.
  149.      *
  150.      * <p>Implements the algorithm of Lemire (2019).</p>
  151.      *
  152.      * @see <a href="https://arxiv.org/abs/1805.10941">Fast Random Integer Generation in an Interval</a>
  153.      */
  154.     private static final class SmallRangeDiscreteUniformSampler
  155.             extends AbstractDiscreteUniformSampler {
  156.         /** Maximum range of the sample (exclusive). */
  157.         private final long n;

  158.         /**
  159.          * The level below which samples are rejected based on the fraction remainder.
  160.          *
  161.          * <p>Any remainder below this denotes that there are still floor(2^32 / n) more
  162.          * observations of this sample from the interval [0, 2^32), where n is the range.</p>
  163.          */
  164.         private final long threshold;

  165.         /**
  166.          * @param rng Generator of uniformly distributed random numbers.
  167.          * @param range Maximum range of the sample (exclusive).
  168.          */
  169.         SmallRangeDiscreteUniformSampler(UniformRandomProvider rng,
  170.                                          int range) {
  171.             super(rng);
  172.             // Handle range as an unsigned 32-bit integer
  173.             this.n = range & 0xffffffffL;
  174.             // Compute 2^32 % n
  175.             threshold = (1L << 32) % n;
  176.         }

  177.         /**
  178.          * @param rng Generator of uniformly distributed random numbers.
  179.          * @param source Source to copy.
  180.          */
  181.         SmallRangeDiscreteUniformSampler(UniformRandomProvider rng,
  182.                                          SmallRangeDiscreteUniformSampler source) {
  183.             super(rng);
  184.             this.n = source.n;
  185.             this.threshold = source.threshold;
  186.         }

  187.         @Override
  188.         public int sample() {
  189.             // Rejection method using multiply by a fraction:
  190.             // n * [0, 2^32 - 1)
  191.             //     -------------
  192.             //         2^32
  193.             // The result is mapped back to an integer and will be in the range [0, n).
  194.             // Note this is comparable to range * rng.nextDouble() but with compensation for
  195.             // non-uniformity due to round-off.
  196.             long result;
  197.             do {
  198.                 // Compute 64-bit unsigned product of n * [0, 2^32 - 1).
  199.                 // The upper 32-bits contains the sample value in the range [0, n), i.e. result / 2^32.
  200.                 // The lower 32-bits contains the remainder, i.e. result % 2^32.
  201.                 result = n * (rng.nextInt() & 0xffffffffL);

  202.                 // Test the sample uniformity.
  203.                 // Samples are observed on average (2^32 / n) times at a frequency of either
  204.                 // floor(2^32 / n) or ceil(2^32 / n).
  205.                 // To ensure all samples have a frequency of floor(2^32 / n) reject any results with
  206.                 // a remainder < (2^32 % n), i.e. the level below which denotes that there are still
  207.                 // floor(2^32 / n) more observations of this sample.
  208.             } while ((result & 0xffffffffL) < threshold);
  209.             // Divide by 2^32 to get the sample.
  210.             return (int)(result >>> 32);
  211.         }

  212.         @Override
  213.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  214.             return new SmallRangeDiscreteUniformSampler(rng, this);
  215.         }
  216.     }

  217.     /**
  218.      * Discrete uniform distribution sampler when the range between lower and upper is too large
  219.      * to fit in a positive integer.
  220.      */
  221.     private static final class LargeRangeDiscreteUniformSampler
  222.             extends AbstractDiscreteUniformSampler {
  223.         /** Lower bound. */
  224.         private final int lower;
  225.         /** Upper bound. */
  226.         private final int upper;

  227.         /**
  228.          * @param rng Generator of uniformly distributed random numbers.
  229.          * @param lower Lower bound (inclusive) of the distribution.
  230.          * @param upper Upper bound (inclusive) of the distribution.
  231.          */
  232.         LargeRangeDiscreteUniformSampler(UniformRandomProvider rng,
  233.                                          int lower,
  234.                                          int upper) {
  235.             super(rng);
  236.             this.lower = lower;
  237.             this.upper = upper;
  238.         }

  239.         @Override
  240.         public int sample() {
  241.             // Use a simple rejection method.
  242.             // This is used when (upper-lower) >= Integer.MAX_VALUE.
  243.             // This will loop on average 2 times in the worst case scenario
  244.             // when (upper-lower) == Integer.MAX_VALUE.
  245.             while (true) {
  246.                 final int r = rng.nextInt();
  247.                 if (r >= lower &&
  248.                     r <= upper) {
  249.                     return r;
  250.                 }
  251.             }
  252.         }

  253.         @Override
  254.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  255.             return new LargeRangeDiscreteUniformSampler(rng, lower, upper);
  256.         }
  257.     }

  258.     /**
  259.      * Adds an offset to an underlying discrete sampler.
  260.      */
  261.     private static final class OffsetDiscreteUniformSampler
  262.             extends AbstractDiscreteUniformSampler {
  263.         /** The offset. */
  264.         private final int offset;
  265.         /** The discrete sampler. */
  266.         private final SharedStateDiscreteSampler sampler;

  267.         /**
  268.          * @param offset The offset for the sample.
  269.          * @param sampler The discrete sampler.
  270.          */
  271.         OffsetDiscreteUniformSampler(int offset,
  272.                                      SharedStateDiscreteSampler sampler) {
  273.             super(null);
  274.             this.offset = offset;
  275.             this.sampler = sampler;
  276.         }

  277.         @Override
  278.         public int sample() {
  279.             return offset + sampler.sample();
  280.         }

  281.         @Override
  282.         public String toString() {
  283.             return sampler.toString();
  284.         }

  285.         @Override
  286.         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  287.             return new OffsetDiscreteUniformSampler(offset, sampler.withUniformRandomProvider(rng));
  288.         }
  289.     }

  290.     /**
  291.      * This instance delegates sampling. Use the factory method
  292.      * {@link #of(UniformRandomProvider, int, int)} to create an optimal sampler.
  293.      *
  294.      * @param rng Generator of uniformly distributed random numbers.
  295.      * @param lower Lower bound (inclusive) of the distribution.
  296.      * @param upper Upper bound (inclusive) of the distribution.
  297.      * @throws IllegalArgumentException if {@code lower > upper}.
  298.      */
  299.     public DiscreteUniformSampler(UniformRandomProvider rng,
  300.                                   int lower,
  301.                                   int upper) {
  302.         this(of(rng, lower, upper));
  303.     }

  304.     /**
  305.      * Private constructor used by to prevent partially initialized object if the construction
  306.      * of the delegate throws. In future versions the public constructor should be removed.
  307.      *
  308.      * @param delegate Delegate.
  309.      */
  310.     private DiscreteUniformSampler(SharedStateDiscreteSampler delegate) {
  311.         super(null);
  312.         this.delegate = delegate;
  313.     }

  314.     /** {@inheritDoc} */
  315.     @Override
  316.     public int sample() {
  317.         return delegate.sample();
  318.     }

  319.     /** {@inheritDoc} */
  320.     @Override
  321.     public String toString() {
  322.         return delegate.toString();
  323.     }

  324.     /**
  325.      * {@inheritDoc}
  326.      *
  327.      * @since 1.3
  328.      */
  329.     @Override
  330.     public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
  331.         // Direct return of the optimised sampler
  332.         return delegate.withUniformRandomProvider(rng);
  333.     }

  334.     /**
  335.      * Creates a new discrete uniform distribution sampler.
  336.      *
  337.      * @param rng Generator of uniformly distributed random numbers.
  338.      * @param lower Lower bound (inclusive) of the distribution.
  339.      * @param upper Upper bound (inclusive) of the distribution.
  340.      * @return the sampler
  341.      * @throws IllegalArgumentException if {@code lower > upper}.
  342.      * @since 1.3
  343.      */
  344.     public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
  345.                                                 int lower,
  346.                                                 int upper) {
  347.         if (lower > upper) {
  348.             throw new IllegalArgumentException(lower  + " > " + upper);
  349.         }

  350.         // Choose the algorithm depending on the range

  351.         // Edge case for no range.
  352.         // This must be done first as the methods to handle lower == 0
  353.         // do not handle upper == 0.
  354.         if (upper == lower) {
  355.             return new FixedDiscreteUniformSampler(lower);
  356.         }

  357.         // Algorithms to ignore the lower bound if it is zero.
  358.         if (lower == 0) {
  359.             return createZeroBoundedSampler(rng, upper);
  360.         }

  361.         final int range = (upper - lower) + 1;
  362.         // Check power of 2 first to handle range == 2^31.
  363.         if (isPowerOf2(range)) {
  364.             return new OffsetDiscreteUniformSampler(lower,
  365.                                                     new PowerOf2RangeDiscreteUniformSampler(rng, range));
  366.         }
  367.         if (range <= 0) {
  368.             // The range is too wide to fit in a positive int (larger
  369.             // than 2^31); use a simple rejection method.
  370.             // Note: if range == 0 then the input is [Integer.MIN_VALUE, Integer.MAX_VALUE].
  371.             // No specialisation exists for this and it is handled as a large range.
  372.             return new LargeRangeDiscreteUniformSampler(rng, lower, upper);
  373.         }
  374.         // Use a sample from the range added to the lower bound.
  375.         return new OffsetDiscreteUniformSampler(lower,
  376.                                                 new SmallRangeDiscreteUniformSampler(rng, range));
  377.     }

  378.     /**
  379.      * Create a new sampler for the range {@code 0} inclusive to {@code upper} inclusive.
  380.      *
  381.      * <p>This can handle any positive {@code upper}.
  382.      *
  383.      * @param rng Generator of uniformly distributed random numbers.
  384.      * @param upper Upper bound (inclusive) of the distribution. Must be positive.
  385.      * @return the sampler
  386.      */
  387.     private static AbstractDiscreteUniformSampler createZeroBoundedSampler(UniformRandomProvider rng,
  388.                                                                            int upper) {
  389.         // Note: Handle any range up to 2^31 (which is negative as a signed
  390.         // 32-bit integer but handled as a power of 2)
  391.         final int range = upper + 1;
  392.         return isPowerOf2(range) ?
  393.             new PowerOf2RangeDiscreteUniformSampler(rng, range) :
  394.             new SmallRangeDiscreteUniformSampler(rng, range);
  395.     }

  396.     /**
  397.      * Checks if the value is a power of 2.
  398.      *
  399.      * <p>This returns {@code true} for the value {@code Integer.MIN_VALUE} which can be
  400.      * handled as an unsigned integer of 2^31.</p>
  401.      *
  402.      * @param value Value.
  403.      * @return {@code true} if a power of 2
  404.      */
  405.     private static boolean isPowerOf2(final int value) {
  406.         return value != 0 && (value & (value - 1)) == 0;
  407.     }
  408. }