UniformDiscreteDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.statistics.distribution;

  18. import org.apache.commons.rng.UniformRandomProvider;
  19. import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;

  20. /**
  21.  * Implementation of the uniform discrete distribution.
  22.  *
  23.  * <p>The probability mass function of \( X \) is:
  24.  *
  25.  * <p>\[ f(k; a, b) = \frac{1}{b-a+1} \]
  26.  *
  27.  * <p>for integer \( a, b \) and \( a \le b \) and
  28.  * \( k \in [a, b] \).
  29.  *
  30.  * @see <a href="https://en.wikipedia.org/wiki/Uniform_distribution_(discrete)">
  31.  * Uniform distribution (discrete) (Wikipedia)</a>
  32.  * @see <a href="https://mathworld.wolfram.com/DiscreteUniformDistribution.html">
  33.  * Discrete uniform distribution (MathWorld)</a>
  34.  */
  35. public final class UniformDiscreteDistribution extends AbstractDiscreteDistribution {
  36.     /** Lower bound (inclusive) of this distribution. */
  37.     private final int lower;
  38.     /** Upper bound (inclusive) of this distribution. */
  39.     private final int upper;
  40.     /** "upper" - "lower" + 1 (as a double to avoid overflow). */
  41.     private final double upperMinusLowerPlus1;
  42.     /** Cache of the probability. */
  43.     private final double pmf;
  44.     /** Cache of the log probability. */
  45.     private final double logPmf;
  46.     /** Value of survival probability for x=0. Used in the inverse survival function. */
  47.     private final double sf0;

  48.     /**
  49.      * @param lower Lower bound (inclusive) of this distribution.
  50.      * @param upper Upper bound (inclusive) of this distribution.
  51.      */
  52.     private UniformDiscreteDistribution(int lower,
  53.                                         int upper) {
  54.         this.lower = lower;
  55.         this.upper = upper;
  56.         upperMinusLowerPlus1 = (double) upper - lower + 1;
  57.         pmf = 1.0 / upperMinusLowerPlus1;
  58.         logPmf = -Math.log(upperMinusLowerPlus1);
  59.         sf0 = (upperMinusLowerPlus1 - 1) / upperMinusLowerPlus1;
  60.     }

  61.     /**
  62.      * Creates a new uniform discrete distribution.
  63.      *
  64.      * @param lower Lower bound (inclusive) of this distribution.
  65.      * @param upper Upper bound (inclusive) of this distribution.
  66.      * @return the distribution
  67.      * @throws IllegalArgumentException if {@code lower > upper}.
  68.      */
  69.     public static UniformDiscreteDistribution of(int lower,
  70.                                                  int upper) {
  71.         if (lower > upper) {
  72.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
  73.                                             lower, upper);
  74.         }
  75.         return new UniformDiscreteDistribution(lower, upper);
  76.     }

  77.     /** {@inheritDoc} */
  78.     @Override
  79.     public double probability(int x) {
  80.         if (x < lower || x > upper) {
  81.             return 0;
  82.         }
  83.         return pmf;
  84.     }

  85.     /** {@inheritDoc} */
  86.     @Override
  87.     public double probability(int x0,
  88.                               int x1) {
  89.         if (x0 > x1) {
  90.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
  91.         }
  92.         if (x0 >= upper || x1 < lower) {
  93.             // (x0, x1] does not overlap [lower, upper]
  94.             return 0;
  95.         }

  96.         // x0 < upper
  97.         // x1 >= lower

  98.         // Find the range between x0 (exclusive) and x1 (inclusive) within [lower, upper].
  99.         // In the case of x0 < lower set l so that u - l == (u - lower) + 1
  100.         // long arithmetic prevents overflow
  101.         final long l = Math.max(lower - 1L, x0);
  102.         final long u = Math.min(upper, x1);

  103.         return (u - l) / upperMinusLowerPlus1;
  104.     }

  105.     /** {@inheritDoc} */
  106.     @Override
  107.     public double logProbability(int x) {
  108.         if (x < lower || x > upper) {
  109.             return Double.NEGATIVE_INFINITY;
  110.         }
  111.         return logPmf;
  112.     }

  113.     /** {@inheritDoc} */
  114.     @Override
  115.     public double cumulativeProbability(int x) {
  116.         if (x <= lower) {
  117.             // Note: CDF(x=0) = PDF(x=0)
  118.             return x == lower ? pmf : 0;
  119.         }
  120.         if (x >= upper) {
  121.             return 1;
  122.         }
  123.         return ((double) x - lower + 1) / upperMinusLowerPlus1;
  124.     }

  125.     /** {@inheritDoc} */
  126.     @Override
  127.     public double survivalProbability(int x) {
  128.         if (x <= lower) {
  129.             // Note: SF(x=0) = 1 - PDF(x=0)
  130.             // Use a pre-computed value to avoid cancellation when probabilityOfSuccess -> 0
  131.             return x == lower ? sf0 : 1;
  132.         }
  133.         if (x >= upper) {
  134.             return 0;
  135.         }
  136.         return ((double) upper - x) / upperMinusLowerPlus1;
  137.     }

  138.     /** {@inheritDoc} */
  139.     @Override
  140.     public int inverseCumulativeProbability(double p) {
  141.         ArgumentUtils.checkProbability(p);
  142.         if (p > sf0) {
  143.             return upper;
  144.         }
  145.         if (p <= pmf) {
  146.             return lower;
  147.         }
  148.         // p in ( pmf         , sf0             ]
  149.         // p in ( 1 / {u-l+1} , {u-l} / {u-l+1} ]
  150.         // x in ( l           , u-1             ]
  151.         int x = (int) (lower + Math.ceil(p * upperMinusLowerPlus1) - 1);

  152.         // Correct rounding errors.
  153.         // This ensures x == icdf(cdf(x))
  154.         // Note: Directly computing the CDF(x-1) avoids integer overflow if x=min_value

  155.         if (((double) x - lower) / upperMinusLowerPlus1 >= p) {
  156.             // No check for x > lower: cdf(x=lower) = 0 and thus is below p
  157.             // cdf(x-1) >= p
  158.             x--;
  159.         } else if (((double) x - lower + 1) / upperMinusLowerPlus1 < p) {
  160.             // No check for x < upper: cdf(x=upper) = 1 and thus is above p
  161.             // cdf(x) < p
  162.             x++;
  163.         }

  164.         return x;
  165.     }

  166.     /** {@inheritDoc} */
  167.     @Override
  168.     public int inverseSurvivalProbability(final double p) {
  169.         ArgumentUtils.checkProbability(p);
  170.         if (p < pmf) {
  171.             return upper;
  172.         }
  173.         if (p >= sf0) {
  174.             return lower;
  175.         }
  176.         // p in [ pmf         , sf0             )
  177.         // p in [ 1 / {u-l+1} , {u-l} / {u-l+1} )
  178.         // x in [ u-1         , l               )
  179.         int x = (int) (upper - Math.floor(p * upperMinusLowerPlus1));

  180.         // Correct rounding errors.
  181.         // This ensures x == isf(sf(x))
  182.         // Note: Directly computing the SF(x-1) avoids integer overflow if x=min_value

  183.         if (((double) upper - x + 1) / upperMinusLowerPlus1 <= p) {
  184.             // No check for x > lower: sf(x=lower) = 1 and thus is above p
  185.             // sf(x-1) <= p
  186.             x--;
  187.         } else if (((double) upper - x) / upperMinusLowerPlus1 > p) {
  188.             // No check for x < upper: sf(x=upper) = 0 and thus is below p
  189.             // sf(x) > p
  190.             x++;
  191.         }

  192.         return x;
  193.     }

  194.     /**
  195.      * {@inheritDoc}
  196.      *
  197.      * <p>For lower bound \( a \) and upper bound \( b \), the mean is \( \frac{1}{2} (a + b) \).
  198.      */
  199.     @Override
  200.     public double getMean() {
  201.         // Avoid overflow
  202.         return 0.5 * ((double) upper + lower);
  203.     }

  204.     /**
  205.      * {@inheritDoc}
  206.      *
  207.      * <p>For lower bound \( a \) and upper bound \( b \), the variance is:
  208.      *
  209.      * <p>\[ \frac{1}{12} (n^2 - 1) \]
  210.      *
  211.      * <p>where \( n = b - a + 1 \).
  212.      */
  213.     @Override
  214.     public double getVariance() {
  215.         return (upperMinusLowerPlus1 * upperMinusLowerPlus1 - 1) / 12;
  216.     }

  217.     /**
  218.      * {@inheritDoc}
  219.      *
  220.      * <p>The lower bound of the support is equal to the lower bound parameter
  221.      * of the distribution.
  222.      */
  223.     @Override
  224.     public int getSupportLowerBound() {
  225.         return lower;
  226.     }

  227.     /**
  228.      * {@inheritDoc}
  229.      *
  230.      * <p>The upper bound of the support is equal to the upper bound parameter
  231.      * of the distribution.
  232.      */
  233.     @Override
  234.     public int getSupportUpperBound() {
  235.         return upper;
  236.     }

  237.     /** {@inheritDoc} */
  238.     @Override
  239.     public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
  240.         // Discrete uniform distribution sampler.
  241.         return DiscreteUniformSampler.of(rng, lower, upper)::sample;
  242.     }
  243. }