AbstractIntegerDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.distribution;

  18. import org.apache.commons.statistics.distribution.DiscreteDistribution;
  19. import org.apache.commons.math4.legacy.exception.MathInternalError;
  20. import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
  21. import org.apache.commons.math4.legacy.exception.OutOfRangeException;
  22. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  23. import org.apache.commons.rng.UniformRandomProvider;
  24. import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;
  25. import org.apache.commons.math4.core.jdkmath.JdkMath;

  26. /**
  27.  * Base class for integer-valued discrete distributions.  Default
  28.  * implementations are provided for some of the methods that do not vary
  29.  * from distribution to distribution.
  30.  *
  31.  */
  32. public abstract class AbstractIntegerDistribution
  33.     implements DiscreteDistribution {
  34.     /**
  35.      * {@inheritDoc}
  36.      *
  37.      * The default implementation uses the identity
  38.      * <p>{@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)}</p>
  39.      *
  40.      * @since 4.0, was previously named cumulativeProbability
  41.      */
  42.     @Override
  43.     public double probability(int x0, int x1) throws NumberIsTooLargeException {
  44.         if (x1 < x0) {
  45.             throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT,
  46.                     x0, x1, true);
  47.         }
  48.         return cumulativeProbability(x1) - cumulativeProbability(x0);
  49.     }

  50.     /**
  51.      * {@inheritDoc}
  52.      *
  53.      * The default implementation returns
  54.      * <ul>
  55.      * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
  56.      * <li>{@link #getSupportUpperBound()} for {@code p = 1}, and</li>
  57.      * <li>{@link #solveInverseCumulativeProbability(double, int, int)} for
  58.      *     {@code 0 < p < 1}.</li>
  59.      * </ul>
  60.      */
  61.     @Override
  62.     public int inverseCumulativeProbability(final double p) throws OutOfRangeException {
  63.         if (p < 0.0 || p > 1.0) {
  64.             throw new OutOfRangeException(p, 0, 1);
  65.         }

  66.         int lower = getSupportLowerBound();
  67.         if (p == 0.0) {
  68.             return lower;
  69.         }
  70.         if (lower == Integer.MIN_VALUE) {
  71.             if (checkedCumulativeProbability(lower) >= p) {
  72.                 return lower;
  73.             }
  74.         } else {
  75.             lower -= 1; // this ensures cumulativeProbability(lower) < p, which
  76.                         // is important for the solving step
  77.         }

  78.         int upper = getSupportUpperBound();
  79.         if (p == 1.0) {
  80.             return upper;
  81.         }

  82.         // use the one-sided Chebyshev inequality to narrow the bracket
  83.         // cf. AbstractRealDistribution.inverseCumulativeProbability(double)
  84.         final double mu = getMean();
  85.         final double sigma = JdkMath.sqrt(getVariance());
  86.         final boolean chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) ||
  87.                 Double.isInfinite(sigma) || Double.isNaN(sigma) || sigma == 0.0);
  88.         if (chebyshevApplies) {
  89.             double k = JdkMath.sqrt((1.0 - p) / p);
  90.             double tmp = mu - k * sigma;
  91.             if (tmp > lower) {
  92.                 lower = ((int) JdkMath.ceil(tmp)) - 1;
  93.             }
  94.             k = 1.0 / k;
  95.             tmp = mu + k * sigma;
  96.             if (tmp < upper) {
  97.                 upper = ((int) JdkMath.ceil(tmp)) - 1;
  98.             }
  99.         }

  100.         return solveInverseCumulativeProbability(p, lower, upper);
  101.     }

  102.     /**
  103.      * This is a utility function used by {@link
  104.      * #inverseCumulativeProbability(double)}. It assumes {@code 0 < p < 1} and
  105.      * that the inverse cumulative probability lies in the bracket {@code
  106.      * (lower, upper]}. The implementation does simple bisection to find the
  107.      * smallest {@code p}-quantile {@code inf{x in Z | P(X<=x) >= p}}.
  108.      *
  109.      * @param p the cumulative probability
  110.      * @param lower a value satisfying {@code cumulativeProbability(lower) < p}
  111.      * @param upper a value satisfying {@code p <= cumulativeProbability(upper)}
  112.      * @return the smallest {@code p}-quantile of this distribution
  113.      */
  114.     protected int solveInverseCumulativeProbability(final double p, int lower, int upper) {
  115.         while (lower + 1 < upper) {
  116.             int xm = (lower + upper) / 2;
  117.             if (xm < lower || xm > upper) {
  118.                 /*
  119.                  * Overflow.
  120.                  * There will never be an overflow in both calculation methods
  121.                  * for xm at the same time
  122.                  */
  123.                 xm = lower + (upper - lower) / 2;
  124.             }

  125.             double pm = checkedCumulativeProbability(xm);
  126.             if (pm >= p) {
  127.                 upper = xm;
  128.             } else {
  129.                 lower = xm;
  130.             }
  131.         }
  132.         return upper;
  133.     }

  134.     /**
  135.      * Computes the cumulative probability function and checks for {@code NaN}
  136.      * values returned. Throws {@code MathInternalError} if the value is
  137.      * {@code NaN}. Rethrows any exception encountered evaluating the cumulative
  138.      * probability function. Throws {@code MathInternalError} if the cumulative
  139.      * probability function returns {@code NaN}.
  140.      *
  141.      * @param argument input value
  142.      * @return the cumulative probability
  143.      * @throws MathInternalError if the cumulative probability is {@code NaN}
  144.      */
  145.     private double checkedCumulativeProbability(int argument)
  146.         throws MathInternalError {
  147.         final double result = cumulativeProbability(argument);
  148.         if (Double.isNaN(result)) {
  149.             throw new MathInternalError(LocalizedFormats
  150.                     .DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN, argument);
  151.         }
  152.         return result;
  153.     }

  154.     /**
  155.      * {@inheritDoc}
  156.      * <p>
  157.      * The default implementation simply computes the logarithm of {@code probability(x)}.
  158.      */
  159.     @Override
  160.     public double logProbability(int x) {
  161.         return JdkMath.log(probability(x));
  162.     }

  163.     /**
  164.      * Utility function for allocating an array and filling it with {@code n}
  165.      * samples generated by the given {@code sampler}.
  166.      *
  167.      * @param n Number of samples.
  168.      * @param sampler Sampler.
  169.      * @return an array of size {@code n}.
  170.      */
  171.     public static int[] sample(int n,
  172.                                DiscreteDistribution.Sampler sampler) {
  173.         final int[] samples = new int[n];
  174.         for (int i = 0; i < n; i++) {
  175.             samples[i] = sampler.sample();
  176.         }
  177.         return samples;
  178.     }

  179.     /**{@inheritDoc} */
  180.     @Override
  181.     public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
  182.         // Inversion method distribution sampler.
  183.         return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
  184.     }
  185. }