AbstractContinuousDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.distribution;

  18. import java.util.function.DoubleBinaryOperator;
  19. import java.util.function.DoubleUnaryOperator;
  20. import org.apache.commons.numbers.rootfinder.BrentSolver;
  21. import org.apache.commons.rng.UniformRandomProvider;
  22. import org.apache.commons.rng.sampling.distribution.InverseTransformContinuousSampler;

  23. /**
  24.  * Base class for probability distributions on the reals.
  25.  * Default implementations are provided for some of the methods
  26.  * that do not vary from distribution to distribution.
  27.  *
  28.  * <p>This base class provides a default factory method for creating
  29.  * a {@linkplain ContinuousDistribution.Sampler sampler instance} that uses the
  30.  * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
  31.  * inversion method</a> for generating random samples that follow the
  32.  * distribution.
  33.  *
  34.  * <p>The class provides functionality to evaluate the probability in a range
  35.  * using either the cumulative probability or the survival probability.
  36.  * The survival probability is used if both arguments to
  37.  * {@link #probability(double, double)} are above the median.
  38.  * Child classes with a known median can override the default {@link #getMedian()}
  39.  * method.
  40.  */
  41. abstract class AbstractContinuousDistribution
  42.     implements ContinuousDistribution {

  43.     // Notes on the inverse probability implementation:
  44.     //
    // The Brent solver does not allow a stopping criterion for the proximity
  46.     // to the root; it uses equality to zero within 1 ULP. The search is
  47.     // iterated until there is a small difference between the upper
  48.     // and lower bracket of the root, expressed as a combination of relative
  49.     // and absolute thresholds.

    /** BrentSolver relative accuracy.
     * This is used with {@code tol = 2 * relEps * abs(b) + absEps} so the minimum
     * non-zero value with an effect is half of machine epsilon (2^-53). */
    private static final double SOLVER_RELATIVE_ACCURACY = 0x1.0p-53;
    /** BrentSolver absolute accuracy.
     * This is used with {@code tol = 2 * relEps * abs(b) + absEps} so set to MIN_VALUE
     * so that when the relative epsilon has no effect (as b is too small) the tolerance
     * is at least 1 ULP for sub-normal numbers.
     * The same value is the smallest step used by the plateau search, which guards
     * against a zero step size. */
    private static final double SOLVER_ABSOLUTE_ACCURACY = Double.MIN_VALUE;
    /** BrentSolver function value accuracy.
     * Determines if the Brent solver performs a search. It is not used during the search.
     * Set to a very low value to search using Brent's method unless
     * the starting point is correct, or within 1 ULP for sub-normal probabilities. */
    private static final double SOLVER_FUNCTION_VALUE_ACCURACY = Double.MIN_VALUE;

    /** Cached value of the median.
     * {@code NaN} marks the value as not yet computed; it is filled in by
     * {@link #getMedian()} on first use. */
    private double median = Double.NaN;

  66.     /**
  67.      * Gets the median. This is used to determine if the arguments to the
  68.      * {@link #probability(double, double)} function are in the upper or lower domain.
  69.      *
  70.      * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
  71.      * with a value of 0.5.
  72.      *
  73.      * @return the median
  74.      */
  75.     double getMedian() {
  76.         double m = median;
  77.         if (Double.isNaN(m)) {
  78.             median = m = inverseCumulativeProbability(0.5);
  79.         }
  80.         return m;
  81.     }

  82.     /** {@inheritDoc} */
  83.     @Override
  84.     public double probability(double x0,
  85.                               double x1) {
  86.         if (x0 > x1) {
  87.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
  88.         }
  89.         // Use the survival probability when in the upper domain [3]:
  90.         //
  91.         //  lower          median         upper
  92.         //    |              |              |
  93.         // 1.     |------|
  94.         //        x0     x1
  95.         // 2.         |----------|
  96.         //            x0         x1
  97.         // 3.                  |--------|
  98.         //                     x0       x1

  99.         final double m = getMedian();
  100.         if (x0 >= m) {
  101.             return survivalProbability(x0) - survivalProbability(x1);
  102.         }
  103.         return cumulativeProbability(x1) - cumulativeProbability(x0);
  104.     }

  105.     /**
  106.      * {@inheritDoc}
  107.      *
  108.      * <p>The default implementation returns:
  109.      * <ul>
  110.      * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
  111.      * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
  112.      * <li>the result of a search for a root between the lower and upper bound using
  113.      *     {@link #cumulativeProbability(double) cumulativeProbability(x) - p}.
  114.      *     The bounds may be bracketed for efficiency.</li>
  115.      * </ul>
  116.      *
  117.      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
  118.      */
  119.     @Override
  120.     public double inverseCumulativeProbability(double p) {
  121.         ArgumentUtils.checkProbability(p);
  122.         return inverseProbability(p, 1 - p, false);
  123.     }

  124.     /**
  125.      * {@inheritDoc}
  126.      *
  127.      * <p>The default implementation returns:
  128.      * <ul>
  129.      * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
  130.      * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
  131.      * <li>the result of a search for a root between the lower and upper bound using
  132.      *     {@link #survivalProbability(double) survivalProbability(x) - p}.
  133.      *     The bounds may be bracketed for efficiency.</li>
  134.      * </ul>
  135.      *
  136.      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
  137.      */
  138.     @Override
  139.     public double inverseSurvivalProbability(double p) {
  140.         ArgumentUtils.checkProbability(p);
  141.         return inverseProbability(1 - p, p, true);
  142.     }

  143.     /**
  144.      * Implementation for the inverse cumulative or survival probability.
  145.      *
  146.      * @param p Cumulative probability.
  147.      * @param q Survival probability.
  148.      * @param complement Set to true to compute the inverse survival probability
  149.      * @return the value
  150.      */
  151.     private double inverseProbability(final double p, final double q, boolean complement) {
  152.         /* IMPLEMENTATION NOTES
  153.          * --------------------
  154.          * Where applicable, use is made of the one-sided Chebyshev inequality
  155.          * to bracket the root. This inequality states that
  156.          * P(X - mu >= k * sig) <= 1 / (1 + k^2),
  157.          * mu: mean, sig: standard deviation. Equivalently
  158.          * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2),
  159.          * F(mu + k * sig) >= k^2 / (1 + k^2).
  160.          *
  161.          * For k = sqrt(p / (1 - p)), we find
  162.          * F(mu + k * sig) >= p,
  163.          * and (mu + k * sig) is an upper-bound for the root.
  164.          *
  165.          * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
  166.          * P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
  167.          * P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
  168.          * P(X <= mu - k * sig) <= 1 / (1 + k^2),
  169.          * F(mu - k * sig) <= 1 / (1 + k^2).
  170.          *
  171.          * For k = sqrt((1 - p) / p), we find
  172.          * F(mu - k * sig) <= p,
  173.          * and (mu - k * sig) is a lower-bound for the root.
  174.          *
  175.          * In cases where the Chebyshev inequality does not apply, geometric
  176.          * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
  177.          * the root.
  178.          *
  179.          * In the case of the survival probability the bracket can be set using the same
  180.          * bound given that the argument p = 1 - q, with q the survival probability.
  181.          */

  182.         double lowerBound = getSupportLowerBound();
  183.         if (p == 0) {
  184.             return lowerBound;
  185.         }
  186.         double upperBound = getSupportUpperBound();
  187.         if (q == 0) {
  188.             return upperBound;
  189.         }

  190.         final double mu = getMean();
  191.         final double sig = Math.sqrt(getVariance());
  192.         final boolean chebyshevApplies = Double.isFinite(mu) &&
  193.                                          ArgumentUtils.isFiniteStrictlyPositive(sig);

  194.         if (lowerBound == Double.NEGATIVE_INFINITY) {
  195.             lowerBound = createFiniteLowerBound(p, q, complement, upperBound, mu, sig, chebyshevApplies);
  196.         }

  197.         if (upperBound == Double.POSITIVE_INFINITY) {
  198.             upperBound = createFiniteUpperBound(p, q, complement, lowerBound, mu, sig, chebyshevApplies);
  199.         }

  200.         // Here the bracket [lower, upper] uses finite values. If the support
  201.         // is infinite the bracket can truncate the distribution and the target
  202.         // probability can be outside the range of [lower, upper].
  203.         if (upperBound == Double.MAX_VALUE) {
  204.             if (complement) {
  205.                 if (survivalProbability(upperBound) > q) {
  206.                     return getSupportUpperBound();
  207.                 }
  208.             } else if (cumulativeProbability(upperBound) < p) {
  209.                 return getSupportUpperBound();
  210.             }
  211.         }
  212.         if (lowerBound == -Double.MAX_VALUE) {
  213.             if (complement) {
  214.                 if (survivalProbability(lowerBound) < q) {
  215.                     return getSupportLowerBound();
  216.                 }
  217.             } else if (cumulativeProbability(lowerBound) > p) {
  218.                 return getSupportLowerBound();
  219.             }
  220.         }

  221.         final DoubleUnaryOperator fun = complement ?
  222.             arg -> survivalProbability(arg) - q :
  223.             arg -> cumulativeProbability(arg) - p;
  224.         // Note the initial value is robust to overflow.
  225.         // Do not use 0.5 * (lowerBound + upperBound).
  226.         final double x = new BrentSolver(SOLVER_RELATIVE_ACCURACY,
  227.                                          SOLVER_ABSOLUTE_ACCURACY,
  228.                                          SOLVER_FUNCTION_VALUE_ACCURACY)
  229.             .findRoot(fun,
  230.                       lowerBound,
  231.                       lowerBound + 0.5 * (upperBound - lowerBound),
  232.                       upperBound);

  233.         if (!isSupportConnected()) {
  234.             return searchPlateau(complement, lowerBound, x);
  235.         }
  236.         return x;
  237.     }

  238.     /**
  239.      * Create a finite lower bound. Assumes the current lower bound is negative infinity.
  240.      *
  241.      * @param p Cumulative probability.
  242.      * @param q Survival probability.
  243.      * @param complement Set to true to compute the inverse survival probability
  244.      * @param upperBound Current upper bound
  245.      * @param mu Mean
  246.      * @param sig Standard deviation
  247.      * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
  248.      * @return the finite lower bound
  249.      */
  250.     private double createFiniteLowerBound(final double p, final double q, boolean complement,
  251.         double upperBound, final double mu, final double sig, final boolean chebyshevApplies) {
  252.         double lowerBound;
  253.         if (chebyshevApplies) {
  254.             lowerBound = mu - sig * Math.sqrt(q / p);
  255.         } else {
  256.             lowerBound = Double.NEGATIVE_INFINITY;
  257.         }
  258.         // Bound may have been set as infinite
  259.         if (lowerBound == Double.NEGATIVE_INFINITY) {
  260.             lowerBound = Math.min(-1, upperBound);
  261.             if (complement) {
  262.                 while (survivalProbability(lowerBound) < q) {
  263.                     lowerBound *= 2;
  264.                 }
  265.             } else {
  266.                 while (cumulativeProbability(lowerBound) >= p) {
  267.                     lowerBound *= 2;
  268.                 }
  269.             }
  270.             // Ensure finite
  271.             lowerBound = Math.max(lowerBound, -Double.MAX_VALUE);
  272.         }
  273.         return lowerBound;
  274.     }

  275.     /**
  276.      * Create a finite upper bound. Assumes the current upper bound is positive infinity.
  277.      *
  278.      * @param p Cumulative probability.
  279.      * @param q Survival probability.
  280.      * @param complement Set to true to compute the inverse survival probability
  281.      * @param lowerBound Current lower bound
  282.      * @param mu Mean
  283.      * @param sig Standard deviation
  284.      * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
  285.      * @return the finite lower bound
  286.      */
  287.     private double createFiniteUpperBound(final double p, final double q, boolean complement,
  288.         double lowerBound, final double mu, final double sig, final boolean chebyshevApplies) {
  289.         double upperBound;
  290.         if (chebyshevApplies) {
  291.             upperBound = mu + sig * Math.sqrt(p / q);
  292.         } else {
  293.             upperBound = Double.POSITIVE_INFINITY;
  294.         }
  295.         // Bound may have been set as infinite
  296.         if (upperBound == Double.POSITIVE_INFINITY) {
  297.             upperBound = Math.max(1, lowerBound);
  298.             if (complement) {
  299.                 while (survivalProbability(upperBound) >= q) {
  300.                     upperBound *= 2;
  301.                 }
  302.             } else {
  303.                 while (cumulativeProbability(upperBound) < p) {
  304.                     upperBound *= 2;
  305.                 }
  306.             }
  307.             // Ensure finite
  308.             upperBound = Math.min(upperBound, Double.MAX_VALUE);
  309.         }
  310.         return upperBound;
  311.     }

  312.     /**
  313.      * Indicates whether the support is connected, i.e. whether all values between the
  314.      * lower and upper bound of the support are included in the support.
  315.      *
  316.      * <p>This method is used in the default implementation of the inverse cumulative and
  317.      * survival probability functions.
  318.      *
  319.      * <p>The default value is true which assumes the cdf and sf have no plateau regions
  320.      * where the same probability value is returned for a large range of x.
  321.      * Override this method if there are gaps in the support of the cdf and sf.
  322.      *
  323.      * <p>If false then the inverse will perform an additional step to ensure that the
  324.      * lower-bound of the interval on which the cdf is constant should be returned. This
  325.      * will search from the initial point x downwards if a smaller value also has the same
  326.      * cumulative (survival) probability.
  327.      *
  328.      * <p>Any plateau with a width in x smaller than the inverse absolute accuracy will
  329.      * not be searched.
  330.      *
  331.      * <p>Note: This method was public in commons math. It has been reduced to package private
  332.      * in commons statistics as it is an implementation detail.
  333.      *
  334.      * @return whether the support is connected.
  335.      * @see <a href="https://issues.apache.org/jira/browse/MATH-699">MATH-699</a>
  336.      */
  337.     boolean isSupportConnected() {
  338.         return true;
  339.     }

  340.     /**
  341.      * Test the probability function for a plateau at the point x. If detected
  342.      * search the plateau for the lowest point y such that
  343.      * {@code inf{y in R | P(y) == P(x)}}.
  344.      *
  345.      * <p>This function is used when the distribution support is not connected
  346.      * to satisfy the inverse probability requirements of {@link ContinuousDistribution}
  347.      * on the returned value.
  348.      *
  349.      * @param complement Set to true to search the survival probability.
  350.      * @param lower Lower bound used to limit the search downwards.
  351.      * @param x Current value.
  352.      * @return the infimum y
  353.      */
  354.     private double searchPlateau(boolean complement, double lower, final double x) {
  355.         // Test for plateau. Lower the value x if the probability is the same.
  356.         // Ensure the step is robust to the solver accuracy being less
  357.         // than 1 ulp of x (e.g. dx=0 will infinite loop)
  358.         final double dx = Math.max(SOLVER_ABSOLUTE_ACCURACY, Math.ulp(x));
  359.         if (x - dx >= lower) {
  360.             final DoubleUnaryOperator fun = complement ?
  361.                 this::survivalProbability :
  362.                 this::cumulativeProbability;
  363.             final double px = fun.applyAsDouble(x);
  364.             if (fun.applyAsDouble(x - dx) == px) {
  365.                 double upperBound = x;
  366.                 double lowerBound = lower;
  367.                 // Bisection search
  368.                 // Require cdf(x) < px and sf(x) > px to move the lower bound
  369.                 // to the midpoint.
  370.                 final DoubleBinaryOperator cmp = complement ?
  371.                     (a, b) -> a > b ? -1 : 1 :
  372.                     (a, b) -> a < b ? -1 : 1;
  373.                 while (upperBound - lowerBound > dx) {
  374.                     final double midPoint = 0.5 * (lowerBound + upperBound);
  375.                     if (cmp.applyAsDouble(fun.applyAsDouble(midPoint), px) < 0) {
  376.                         lowerBound = midPoint;
  377.                     } else {
  378.                         upperBound = midPoint;
  379.                     }
  380.                 }
  381.                 return upperBound;
  382.             }
  383.         }
  384.         return x;
  385.     }

  386.     /** {@inheritDoc} */
  387.     @Override
  388.     public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
  389.         // Inversion method distribution sampler.
  390.         return InverseTransformContinuousSampler.of(rng, this::inverseCumulativeProbability)::sample;
  391.     }
  392. }