GeometricDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.distribution;

  18. import java.util.function.IntToDoubleFunction;
  19. import org.apache.commons.rng.UniformRandomProvider;
  20. import org.apache.commons.rng.sampling.distribution.GeometricSampler;

  21. /**
  22.  * Implementation of the geometric distribution.
  23.  *
  24.  * <p>The probability mass function of \( X \) is:
  25.  *
  26.  * <p>\[ f(k; p) = (1-p)^k \, p \]
  27.  *
  28.  * <p>for \( p \in (0, 1] \) the probability of success and
  29.  * \( k \in \{0, 1, 2, \dots\} \) the number of failures.
  30.  *
  31.  * <p>This parameterization is used to model the number of failures until
  32.  * the first success.
  33.  *
  34.  * @see <a href="https://en.wikipedia.org/wiki/Geometric_distribution">Geometric distribution (Wikipedia)</a>
  35.  * @see <a href="https://mathworld.wolfram.com/GeometricDistribution.html">Geometric distribution (MathWorld)</a>
  36.  */
  37. public final class GeometricDistribution extends AbstractDiscreteDistribution {
  38.     /** 1/2. */
  39.     private static final double HALF = 0.5;

  40.     /** The probability of success. */
  41.     private final double probabilityOfSuccess;
  42.     /** {@code log(p)} where p is the probability of success. */
  43.     private final double logProbabilityOfSuccess;
  44.     /** {@code log(1 - p)} where p is the probability of success. */
  45.     private final double log1mProbabilityOfSuccess;
  46.     /** Value of survival probability for x=0.
  47.      * Used in the survival functions. Equal to (1 - probability of success). */
  48.     private final double sf0;
  49.     /** Implementation of PMF(x). Assumes that {@code x > 0}. */
  50.     private final IntToDoubleFunction pmf;

  51.     /**
  52.      * @param p Probability of success.
  53.      */
  54.     private GeometricDistribution(double p) {
  55.         probabilityOfSuccess = p;
  56.         logProbabilityOfSuccess = Math.log(p);
  57.         log1mProbabilityOfSuccess = Math.log1p(-p);
  58.         sf0 = 1 - p;

  59.         // Choose the PMF implementation.
  60.         // When p >= 0.5 then 1 - p is exact and using the power function
  61.         // is consistently more accurate than the use of the exponential function.
  62.         // When p -> 0 then the exponential function avoids large error propagation
  63.         // of the power function used with an inexact 1 - p.
  64.         // Also compute the survival probability for use when x=0.
  65.         if (p >= HALF) {
  66.             pmf = x -> Math.pow(sf0, x) * probabilityOfSuccess;
  67.         } else {
  68.             pmf = x -> Math.exp(log1mProbabilityOfSuccess * x) * probabilityOfSuccess;
  69.         }
  70.     }

  71.     /**
  72.      * Creates a geometric distribution.
  73.      *
  74.      * @param p Probability of success.
  75.      * @return the geometric distribution
  76.      * @throws IllegalArgumentException if {@code p <= 0} or {@code p > 1}.
  77.      */
  78.     public static GeometricDistribution of(double p) {
  79.         if (p <= 0 || p > 1) {
  80.             throw new DistributionException(DistributionException.INVALID_NON_ZERO_PROBABILITY, p);
  81.         }
  82.         return new GeometricDistribution(p);
  83.     }

  84.     /**
  85.      * Gets the probability of success parameter of this distribution.
  86.      *
  87.      * @return the probability of success.
  88.      */
  89.     public double getProbabilityOfSuccess() {
  90.         return probabilityOfSuccess;
  91.     }

  92.     /** {@inheritDoc} */
  93.     @Override
  94.     public double probability(int x) {
  95.         if (x <= 0) {
  96.             // Special case of x=0 exploiting cancellation.
  97.             return x == 0 ? probabilityOfSuccess : 0;
  98.         }
  99.         return pmf.applyAsDouble(x);
  100.     }

  101.     /** {@inheritDoc} */
  102.     @Override
  103.     public double logProbability(int x) {
  104.         if (x <= 0) {
  105.             // Special case of x=0 exploiting cancellation.
  106.             return x == 0 ? logProbabilityOfSuccess : Double.NEGATIVE_INFINITY;
  107.         }
  108.         return x * log1mProbabilityOfSuccess + logProbabilityOfSuccess;
  109.     }

  110.     /** {@inheritDoc} */
  111.     @Override
  112.     public double cumulativeProbability(int x) {
  113.         if (x <= 0) {
  114.             // Note: CDF(x=0) = PDF(x=0) = probabilityOfSuccess
  115.             return x == 0 ? probabilityOfSuccess : 0;
  116.         }
  117.         // Note: Double addition avoids overflow. This may compute a value less than 1.0
  118.         // for the max integer value when p is very small.
  119.         return -Math.expm1(log1mProbabilityOfSuccess * (x + 1.0));
  120.     }

  121.     /** {@inheritDoc} */
  122.     @Override
  123.     public double survivalProbability(int x) {
  124.         if (x <= 0) {
  125.             // Note: SF(x=0) = 1 - PDF(x=0) = 1 - probabilityOfSuccess
  126.             // Use a pre-computed value to avoid cancellation when probabilityOfSuccess -> 0
  127.             return x == 0 ? sf0 : 1;
  128.         }
  129.         // Note: Double addition avoids overflow. This may compute a value greater than 0.0
  130.         // for the max integer value when p is very small.
  131.         return Math.exp(log1mProbabilityOfSuccess * (x + 1.0));
  132.     }

  133.     /** {@inheritDoc} */
  134.     @Override
  135.     public int inverseCumulativeProbability(double p) {
  136.         ArgumentUtils.checkProbability(p);
  137.         if (p == 1) {
  138.             return getSupportUpperBound();
  139.         }
  140.         if (p <= probabilityOfSuccess) {
  141.             return 0;
  142.         }
  143.         // p > probabilityOfSuccess
  144.         // => log(1-p) < log(1-probabilityOfSuccess);
  145.         // Both terms are negative as probabilityOfSuccess > 0.
  146.         // This should be lower bounded to (2 - 1) = 1
  147.         int x = (int) (Math.ceil(Math.log1p(-p) / log1mProbabilityOfSuccess) - 1);

  148.         // Correct rounding errors.
  149.         // This ensures x == icdf(cdf(x))

  150.         if (cumulativeProbability(x - 1) >= p) {
  151.             // No checks for x=0.
  152.             // If x=0; cdf(-1) = 0 and the condition is false as p>0 at this point.
  153.             x--;
  154.         } else if (cumulativeProbability(x) < p && x < Integer.MAX_VALUE) {
  155.             // The supported upper bound is max_value here as probabilityOfSuccess != 1
  156.             x++;
  157.         }

  158.         return x;
  159.     }

  160.     /** {@inheritDoc} */
  161.     @Override
  162.     public int inverseSurvivalProbability(double p) {
  163.         ArgumentUtils.checkProbability(p);
  164.         if (p == 0) {
  165.             return getSupportUpperBound();
  166.         }
  167.         if (p >= sf0) {
  168.             return 0;
  169.         }

  170.         // p < 1 - probabilityOfSuccess
  171.         // Inversion as for icdf using log(p) in place of log1p(-p)
  172.         int x = (int) (Math.ceil(Math.log(p) / log1mProbabilityOfSuccess) - 1);

  173.         // Correct rounding errors.
  174.         // This ensures x == isf(sf(x))

  175.         if (survivalProbability(x - 1) <= p) {
  176.             // No checks for x=0
  177.             // If x=0; sf(-1) = 1 and the condition is false as p<1 at this point.
  178.             x--;
  179.         } else if (survivalProbability(x) > p && x < Integer.MAX_VALUE) {
  180.             // The supported upper bound is max_value here as probabilityOfSuccess != 1
  181.             x++;
  182.         }

  183.         return x;
  184.     }

  185.     /**
  186.      * {@inheritDoc}
  187.      *
  188.      * <p>For probability parameter \( p \), the mean is:
  189.      *
  190.      * <p>\[ \frac{1 - p}{p} \]
  191.      */
  192.     @Override
  193.     public double getMean() {
  194.         return (1 - probabilityOfSuccess) / probabilityOfSuccess;
  195.     }

  196.     /**
  197.      * {@inheritDoc}
  198.      *
  199.      * <p>For probability parameter \( p \), the variance is:
  200.      *
  201.      * <p>\[ \frac{1 - p}{p^2} \]
  202.      */
  203.     @Override
  204.     public double getVariance() {
  205.         return (1 - probabilityOfSuccess) / (probabilityOfSuccess * probabilityOfSuccess);
  206.     }

  207.     /**
  208.      * {@inheritDoc}
  209.      *
  210.      * <p>The lower bound of the support is always 0.
  211.      *
  212.      * @return 0.
  213.      */
  214.     @Override
  215.     public int getSupportLowerBound() {
  216.         return 0;
  217.     }

  218.     /**
  219.      * {@inheritDoc}
  220.      *
  221.      * <p>The upper bound of the support is positive infinity except for the
  222.      * probability parameter {@code p = 1.0}.
  223.      *
  224.      * @return {@link Integer#MAX_VALUE} or 0.
  225.      */
  226.     @Override
  227.     public int getSupportUpperBound() {
  228.         return probabilityOfSuccess < 1 ? Integer.MAX_VALUE : 0;
  229.     }

  230.     /** {@inheritDoc} */
  231.     @Override
  232.     public Sampler createSampler(UniformRandomProvider rng) {
  233.         return GeometricSampler.of(rng, probabilityOfSuccess)::sample;
  234.     }
  235. }