AbstractDiscreteDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.distribution;

  18. import java.util.function.IntUnaryOperator;
  19. import org.apache.commons.rng.UniformRandomProvider;
  20. import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;

  21. /**
  22.  * Base class for integer-valued discrete distributions.  Default
  23.  * implementations are provided for some of the methods that do not vary
  24.  * from distribution to distribution.
  25.  *
  26.  * <p>This base class provides a default factory method for creating
  27.  * a {@linkplain DiscreteDistribution.Sampler sampler instance} that uses the
  28.  * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
  29.  * inversion method</a> for generating random samples that follow the
  30.  * distribution.
  31.  *
  32.  * <p>The class provides functionality to evaluate the probability in a range
  33.  * using either the cumulative probability or the survival probability.
  34.  * The survival probability is used if both arguments to
  35.  * {@link #probability(int, int)} are above the median.
  36.  * Child classes with a known median can override the default {@link #getMedian()}
  37.  * method.
  38.  */
  39. abstract class AbstractDiscreteDistribution
  40.     implements DiscreteDistribution {
  41.     /** Marker value for no median.
  42.      * This is a long to be outside the value of any possible int valued median. */
  43.     private static final long NO_MEDIAN = Long.MIN_VALUE;

  44.     /** Cached value of the median. */
  45.     private long median = NO_MEDIAN;

  46.     /**
  47.      * Gets the median. This is used to determine if the arguments to the
  48.      * {@link #probability(int, int)} function are in the upper or lower domain.
  49.      *
  50.      * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
  51.      * with a value of 0.5.
  52.      *
  53.      * @return the median
  54.      */
  55.     int getMedian() {
  56.         long m = median;
  57.         if (m == NO_MEDIAN) {
  58.             median = m = inverseCumulativeProbability(0.5);
  59.         }
  60.         return (int) m;
  61.     }

  62.     /** {@inheritDoc} */
  63.     @Override
  64.     public double probability(int x0,
  65.                               int x1) {
  66.         if (x0 > x1) {
  67.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
  68.         }
  69.         // As per the default interface method handle special cases:
  70.         // x0     = x1 : return 0
  71.         // x0 + 1 = x1 : return probability(x1)
  72.         // Long addition avoids overflow
  73.         if (x0 + 1L >= x1) {
  74.             return x0 == x1 ? 0.0 : probability(x1);
  75.         }

  76.         // Use the survival probability when in the upper domain [3]:
  77.         //
  78.         //  lower          median         upper
  79.         //    |              |              |
  80.         // 1.     |------|
  81.         //        x0     x1
  82.         // 2.         |----------|
  83.         //            x0         x1
  84.         // 3.                  |--------|
  85.         //                     x0       x1

  86.         final double m = getMedian();
  87.         if (x0 >= m) {
  88.             return survivalProbability(x0) - survivalProbability(x1);
  89.         }
  90.         return cumulativeProbability(x1) - cumulativeProbability(x0);
  91.     }

  92.     /**
  93.      * {@inheritDoc}
  94.      *
  95.      * <p>The default implementation returns:
  96.      * <ul>
  97.      * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
  98.      * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
  99.      * <li>the result of a binary search between the lower and upper bound using
  100.      *     {@link #cumulativeProbability(int) cumulativeProbability(x)}.
  101.      *     The bounds may be bracketed for efficiency.</li>
  102.      * </ul>
  103.      *
  104.      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
  105.      */
  106.     @Override
  107.     public int inverseCumulativeProbability(double p) {
  108.         ArgumentUtils.checkProbability(p);
  109.         return inverseProbability(p, 1 - p, false);
  110.     }

  111.     /**
  112.      * {@inheritDoc}
  113.      *
  114.      * <p>The default implementation returns:
  115.      * <ul>
  116.      * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
  117.      * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
  118.      * <li>the result of a binary search between the lower and upper bound using
  119.      *     {@link #survivalProbability(int) survivalProbability(x)}.
  120.      *     The bounds may be bracketed for efficiency.</li>
  121.      * </ul>
  122.      *
  123.      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
  124.      */
  125.     @Override
  126.     public int inverseSurvivalProbability(double p) {
  127.         ArgumentUtils.checkProbability(p);
  128.         return inverseProbability(1 - p, p, true);
  129.     }

  130.     /**
  131.      * Implementation for the inverse cumulative or survival probability.
  132.      *
  133.      * @param p Cumulative probability.
  134.      * @param q Survival probability.
  135.      * @param complement Set to true to compute the inverse survival probability
  136.      * @return the value
  137.      */
  138.     private int inverseProbability(double p, double q, boolean complement) {

  139.         int lower = getSupportLowerBound();
  140.         if (p == 0) {
  141.             return lower;
  142.         }
  143.         int upper = getSupportUpperBound();
  144.         if (q == 0) {
  145.             return upper;
  146.         }

  147.         // The binary search sets the upper value to the mid-point
  148.         // based on fun(x) >= 0. The upper value is returned.
  149.         //
  150.         // Create a function to search for x where the upper bound can be
  151.         // lowered if:
  152.         // cdf(x) >= p
  153.         // sf(x)  <= q
  154.         final IntUnaryOperator fun = complement ?
  155.             x -> Double.compare(q, survivalProbability(x)) :
  156.             x -> Double.compare(cumulativeProbability(x), p);

  157.         if (lower == Integer.MIN_VALUE) {
  158.             if (fun.applyAsInt(lower) >= 0) {
  159.                 return lower;
  160.             }
  161.         } else {
  162.             // this ensures:
  163.             // cumulativeProbability(lower) < p
  164.             // survivalProbability(lower) > q
  165.             // which is important for the solving step
  166.             lower -= 1;
  167.         }

  168.         // use the one-sided Chebyshev inequality to narrow the bracket
  169.         // cf. AbstractContinuousDistribution.inverseCumulativeProbability(double)
  170.         final double mu = getMean();
  171.         final double sig = Math.sqrt(getVariance());
  172.         final boolean chebyshevApplies = Double.isFinite(mu) &&
  173.                                          ArgumentUtils.isFiniteStrictlyPositive(sig);

  174.         if (chebyshevApplies) {
  175.             double tmp = mu - sig * Math.sqrt(q / p);
  176.             if (tmp > lower) {
  177.                 lower = ((int) Math.ceil(tmp)) - 1;
  178.             }
  179.             tmp = mu + sig * Math.sqrt(p / q);
  180.             if (tmp < upper) {
  181.                 upper = ((int) Math.ceil(tmp)) - 1;
  182.             }
  183.         }

  184.         return solveInverseProbability(fun, lower, upper);
  185.     }

  186.     /**
  187.      * This is a utility function used by {@link
  188.      * #inverseProbability(double, double, boolean)}. It assumes
  189.      * that the inverse probability lies in the bracket {@code
  190.      * (lower, upper]}. The implementation does simple bisection to find the
  191.      * smallest {@code x} such that {@code fun(x) >= 0}.
  192.      *
  193.      * @param fun Probability function.
  194.      * @param lowerBound Value satisfying {@code fun(lower) < 0}.
  195.      * @param upperBound Value satisfying {@code fun(upper) >= 0}.
  196.      * @return the smallest x
  197.      */
  198.     private static int solveInverseProbability(IntUnaryOperator fun,
  199.                                                int lowerBound,
  200.                                                int upperBound) {
  201.         // Use long to prevent overflow during computation of the middle
  202.         long lower = lowerBound;
  203.         long upper = upperBound;
  204.         while (lower + 1 < upper) {
  205.             // Note: Cannot replace division by 2 with a right shift because
  206.             // (lower + upper) can be negative.
  207.             final long middle = (lower + upper) / 2;
  208.             final int pm = fun.applyAsInt((int) middle);
  209.             if (pm < 0) {
  210.                 lower = middle;
  211.             } else {
  212.                 upper = middle;
  213.             }
  214.         }
  215.         return (int) upper;
  216.     }

  217.     /** {@inheritDoc} */
  218.     @Override
  219.     public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
  220.         // Inversion method distribution sampler.
  221.         return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
  222.     }
  223. }