AbstractDiscreteDistribution.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.commons.statistics.distribution;
- import java.util.function.IntUnaryOperator;
- import org.apache.commons.rng.UniformRandomProvider;
- import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;
- /**
- * Base class for integer-valued discrete distributions. Default
- * implementations are provided for some of the methods that do not vary
- * from distribution to distribution.
- *
- * <p>This base class provides a default factory method for creating
- * a {@linkplain DiscreteDistribution.Sampler sampler instance} that uses the
- * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
- * inversion method</a> for generating random samples that follow the
- * distribution.
- *
- * <p>The class provides functionality to evaluate the probability in a range
- * using either the cumulative probability or the survival probability.
- * The survival probability is used if both arguments to
- * {@link #probability(int, int)} are above the median.
- * Child classes with a known median can override the default {@link #getMedian()}
- * method.
- */
- abstract class AbstractDiscreteDistribution
- implements DiscreteDistribution {
- /** Marker value for no median.
- * This is a long to be outside the value of any possible int valued median. */
- private static final long NO_MEDIAN = Long.MIN_VALUE;
- /** Cached value of the median. */
- private long median = NO_MEDIAN;
- /**
- * Gets the median. This is used to determine if the arguments to the
- * {@link #probability(int, int)} function are in the upper or lower domain.
- *
- * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
- * with a value of 0.5.
- *
- * @return the median
- */
- int getMedian() {
- long m = median;
- if (m == NO_MEDIAN) {
- median = m = inverseCumulativeProbability(0.5);
- }
- return (int) m;
- }
- /** {@inheritDoc} */
- @Override
- public double probability(int x0,
- int x1) {
- if (x0 > x1) {
- throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
- }
- // As per the default interface method handle special cases:
- // x0 = x1 : return 0
- // x0 + 1 = x1 : return probability(x1)
- // Long addition avoids overflow
- if (x0 + 1L >= x1) {
- return x0 == x1 ? 0.0 : probability(x1);
- }
- // Use the survival probability when in the upper domain [3]:
- //
- // lower median upper
- // | | |
- // 1. |------|
- // x0 x1
- // 2. |----------|
- // x0 x1
- // 3. |--------|
- // x0 x1
- final double m = getMedian();
- if (x0 >= m) {
- return survivalProbability(x0) - survivalProbability(x1);
- }
- return cumulativeProbability(x1) - cumulativeProbability(x0);
- }
- /**
- * {@inheritDoc}
- *
- * <p>The default implementation returns:
- * <ul>
- * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
- * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
- * <li>the result of a binary search between the lower and upper bound using
- * {@link #cumulativeProbability(int) cumulativeProbability(x)}.
- * The bounds may be bracketed for efficiency.</li>
- * </ul>
- *
- * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
- */
- @Override
- public int inverseCumulativeProbability(double p) {
- ArgumentUtils.checkProbability(p);
- return inverseProbability(p, 1 - p, false);
- }
- /**
- * {@inheritDoc}
- *
- * <p>The default implementation returns:
- * <ul>
- * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
- * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
- * <li>the result of a binary search between the lower and upper bound using
- * {@link #survivalProbability(int) survivalProbability(x)}.
- * The bounds may be bracketed for efficiency.</li>
- * </ul>
- *
- * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
- */
- @Override
- public int inverseSurvivalProbability(double p) {
- ArgumentUtils.checkProbability(p);
- return inverseProbability(1 - p, p, true);
- }
- /**
- * Implementation for the inverse cumulative or survival probability.
- *
- * @param p Cumulative probability.
- * @param q Survival probability.
- * @param complement Set to true to compute the inverse survival probability
- * @return the value
- */
- private int inverseProbability(double p, double q, boolean complement) {
- int lower = getSupportLowerBound();
- if (p == 0) {
- return lower;
- }
- int upper = getSupportUpperBound();
- if (q == 0) {
- return upper;
- }
- // The binary search sets the upper value to the mid-point
- // based on fun(x) >= 0. The upper value is returned.
- //
- // Create a function to search for x where the upper bound can be
- // lowered if:
- // cdf(x) >= p
- // sf(x) <= q
- final IntUnaryOperator fun = complement ?
- x -> Double.compare(q, survivalProbability(x)) :
- x -> Double.compare(cumulativeProbability(x), p);
- if (lower == Integer.MIN_VALUE) {
- if (fun.applyAsInt(lower) >= 0) {
- return lower;
- }
- } else {
- // this ensures:
- // cumulativeProbability(lower) < p
- // survivalProbability(lower) > q
- // which is important for the solving step
- lower -= 1;
- }
- // use the one-sided Chebyshev inequality to narrow the bracket
- // cf. AbstractContinuousDistribution.inverseCumulativeProbability(double)
- final double mu = getMean();
- final double sig = Math.sqrt(getVariance());
- final boolean chebyshevApplies = Double.isFinite(mu) &&
- ArgumentUtils.isFiniteStrictlyPositive(sig);
- if (chebyshevApplies) {
- double tmp = mu - sig * Math.sqrt(q / p);
- if (tmp > lower) {
- lower = ((int) Math.ceil(tmp)) - 1;
- }
- tmp = mu + sig * Math.sqrt(p / q);
- if (tmp < upper) {
- upper = ((int) Math.ceil(tmp)) - 1;
- }
- }
- return solveInverseProbability(fun, lower, upper);
- }
- /**
- * This is a utility function used by {@link
- * #inverseProbability(double, double, boolean)}. It assumes
- * that the inverse probability lies in the bracket {@code
- * (lower, upper]}. The implementation does simple bisection to find the
- * smallest {@code x} such that {@code fun(x) >= 0}.
- *
- * @param fun Probability function.
- * @param lowerBound Value satisfying {@code fun(lower) < 0}.
- * @param upperBound Value satisfying {@code fun(upper) >= 0}.
- * @return the smallest x
- */
- private static int solveInverseProbability(IntUnaryOperator fun,
- int lowerBound,
- int upperBound) {
- // Use long to prevent overflow during computation of the middle
- long lower = lowerBound;
- long upper = upperBound;
- while (lower + 1 < upper) {
- // Note: Cannot replace division by 2 with a right shift because
- // (lower + upper) can be negative.
- final long middle = (lower + upper) / 2;
- final int pm = fun.applyAsInt((int) middle);
- if (pm < 0) {
- lower = middle;
- } else {
- upper = middle;
- }
- }
- return (int) upper;
- }
- /** {@inheritDoc} */
- @Override
- public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
- // Inversion method distribution sampler.
- return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
- }
- }