View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.distribution;
18  
19  import java.util.function.DoubleBinaryOperator;
20  import java.util.function.DoubleUnaryOperator;
21  import org.apache.commons.numbers.rootfinder.BrentSolver;
22  import org.apache.commons.rng.UniformRandomProvider;
23  import org.apache.commons.rng.sampling.distribution.InverseTransformContinuousSampler;
24  
25  /**
26   * Base class for probability distributions on the reals.
27   * Default implementations are provided for some of the methods
28   * that do not vary from distribution to distribution.
29   *
30   * <p>This base class provides a default factory method for creating
31   * a {@linkplain ContinuousDistribution.Sampler sampler instance} that uses the
32   * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
33   * inversion method</a> for generating random samples that follow the
34   * distribution.
35   *
36   * <p>The class provides functionality to evaluate the probability in a range
37   * using either the cumulative probability or the survival probability.
38   * The survival probability is used if both arguments to
39   * {@link #probability(double, double)} are above the median.
40   * Child classes with a known median can override the default {@link #getMedian()}
41   * method.
42   */
43  abstract class AbstractContinuousDistribution
44      implements ContinuousDistribution {
45  
46      // Notes on the inverse probability implementation:
47      //
48      // The Brent solver does not allow a stopping criteria for the proximity
49      // to the root; it uses equality to zero within 1 ULP. The search is
50      // iterated until there is a small difference between the upper
51      // and lower bracket of the root, expressed as a combination of relative
52      // and absolute thresholds.
53  
54      /** BrentSolver relative accuracy.
55       * This is used with {@code tol = 2 * relEps * abs(b) + absEps} so the minimum
56       * non-zero value with an effect is half of machine epsilon (2^-53). */
57      private static final double SOLVER_RELATIVE_ACCURACY = 0x1.0p-53;
58      /** BrentSolver absolute accuracy.
59       * This is used with {@code tol = 2 * relEps * abs(b) + absEps} so set to MIN_VALUE
60       * so that when the relative epsilon has no effect (as b is too small) the tolerance
61       * is at least 1 ULP for sub-normal numbers. */
62      private static final double SOLVER_ABSOLUTE_ACCURACY = Double.MIN_VALUE;
63      /** BrentSolver function value accuracy.
64       * Determines if the Brent solver performs a search. It is not used during the search.
65       * Set to a very low value to search using Brent's method unless
66       * the starting point is correct, or within 1 ULP for sub-normal probabilities. */
67      private static final double SOLVER_FUNCTION_VALUE_ACCURACY = Double.MIN_VALUE;
68  
69      /** Cached value of the median. */
70      private double median = Double.NaN;
71  
72      /**
73       * Gets the median. This is used to determine if the arguments to the
74       * {@link #probability(double, double)} function are in the upper or lower domain.
75       *
76       * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
77       * with a value of 0.5.
78       *
79       * @return the median
80       */
81      double getMedian() {
82          double m = median;
83          if (Double.isNaN(m)) {
84              median = m = inverseCumulativeProbability(0.5);
85          }
86          return m;
87      }
88  
89      /** {@inheritDoc} */
90      @Override
91      public double probability(double x0,
92                                double x1) {
93          if (x0 > x1) {
94              throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
95          }
96          // Use the survival probability when in the upper domain [3]:
97          //
98          //  lower          median         upper
99          //    |              |              |
100         // 1.     |------|
101         //        x0     x1
102         // 2.         |----------|
103         //            x0         x1
104         // 3.                  |--------|
105         //                     x0       x1
106 
107         final double m = getMedian();
108         if (x0 >= m) {
109             return survivalProbability(x0) - survivalProbability(x1);
110         }
111         return cumulativeProbability(x1) - cumulativeProbability(x0);
112     }
113 
114     /**
115      * {@inheritDoc}
116      *
117      * <p>The default implementation returns:
118      * <ul>
119      * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
120      * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
121      * <li>the result of a search for a root between the lower and upper bound using
122      *     {@link #cumulativeProbability(double) cumulativeProbability(x) - p}.
123      *     The bounds may be bracketed for efficiency.</li>
124      * </ul>
125      *
126      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
127      */
128     @Override
129     public double inverseCumulativeProbability(double p) {
130         ArgumentUtils.checkProbability(p);
131         return inverseProbability(p, 1 - p, false);
132     }
133 
134     /**
135      * {@inheritDoc}
136      *
137      * <p>The default implementation returns:
138      * <ul>
139      * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
140      * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
141      * <li>the result of a search for a root between the lower and upper bound using
142      *     {@link #survivalProbability(double) survivalProbability(x) - p}.
143      *     The bounds may be bracketed for efficiency.</li>
144      * </ul>
145      *
146      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
147      */
148     @Override
149     public double inverseSurvivalProbability(double p) {
150         ArgumentUtils.checkProbability(p);
151         return inverseProbability(1 - p, p, true);
152     }
153 
154     /**
155      * Implementation for the inverse cumulative or survival probability.
156      *
157      * @param p Cumulative probability.
158      * @param q Survival probability.
159      * @param complement Set to true to compute the inverse survival probability
160      * @return the value
161      */
162     private double inverseProbability(final double p, final double q, boolean complement) {
163         /* IMPLEMENTATION NOTES
164          * --------------------
165          * Where applicable, use is made of the one-sided Chebyshev inequality
166          * to bracket the root. This inequality states that
167          * P(X - mu >= k * sig) <= 1 / (1 + k^2),
168          * mu: mean, sig: standard deviation. Equivalently
169          * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2),
170          * F(mu + k * sig) >= k^2 / (1 + k^2).
171          *
172          * For k = sqrt(p / (1 - p)), we find
173          * F(mu + k * sig) >= p,
174          * and (mu + k * sig) is an upper-bound for the root.
175          *
176          * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and
177          * P(Y >= -mu + k * sig) <= 1 / (1 + k^2),
178          * P(-X >= -mu + k * sig) <= 1 / (1 + k^2),
179          * P(X <= mu - k * sig) <= 1 / (1 + k^2),
180          * F(mu - k * sig) <= 1 / (1 + k^2).
181          *
182          * For k = sqrt((1 - p) / p), we find
183          * F(mu - k * sig) <= p,
184          * and (mu - k * sig) is a lower-bound for the root.
185          *
186          * In cases where the Chebyshev inequality does not apply, geometric
187          * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket
188          * the root.
189          *
190          * In the case of the survival probability the bracket can be set using the same
191          * bound given that the argument p = 1 - q, with q the survival probability.
192          */
193 
194         double lowerBound = getSupportLowerBound();
195         if (p == 0) {
196             return lowerBound;
197         }
198         double upperBound = getSupportUpperBound();
199         if (q == 0) {
200             return upperBound;
201         }
202 
203         final double mu = getMean();
204         final double sig = Math.sqrt(getVariance());
205         final boolean chebyshevApplies = Double.isFinite(mu) &&
206                                          ArgumentUtils.isFiniteStrictlyPositive(sig);
207 
208         if (lowerBound == Double.NEGATIVE_INFINITY) {
209             lowerBound = createFiniteLowerBound(p, q, complement, upperBound, mu, sig, chebyshevApplies);
210         }
211 
212         if (upperBound == Double.POSITIVE_INFINITY) {
213             upperBound = createFiniteUpperBound(p, q, complement, lowerBound, mu, sig, chebyshevApplies);
214         }
215 
216         // Here the bracket [lower, upper] uses finite values. If the support
217         // is infinite the bracket can truncate the distribution and the target
218         // probability can be outside the range of [lower, upper].
219         if (upperBound == Double.MAX_VALUE) {
220             if (complement) {
221                 if (survivalProbability(upperBound) > q) {
222                     return getSupportUpperBound();
223                 }
224             } else if (cumulativeProbability(upperBound) < p) {
225                 return getSupportUpperBound();
226             }
227         }
228         if (lowerBound == -Double.MAX_VALUE) {
229             if (complement) {
230                 if (survivalProbability(lowerBound) < q) {
231                     return getSupportLowerBound();
232                 }
233             } else if (cumulativeProbability(lowerBound) > p) {
234                 return getSupportLowerBound();
235             }
236         }
237 
238         final DoubleUnaryOperator fun = complement ?
239             arg -> survivalProbability(arg) - q :
240             arg -> cumulativeProbability(arg) - p;
241         // Note the initial value is robust to overflow.
242         // Do not use 0.5 * (lowerBound + upperBound).
243         final double x = new BrentSolver(SOLVER_RELATIVE_ACCURACY,
244                                          SOLVER_ABSOLUTE_ACCURACY,
245                                          SOLVER_FUNCTION_VALUE_ACCURACY)
246             .findRoot(fun,
247                       lowerBound,
248                       lowerBound + 0.5 * (upperBound - lowerBound),
249                       upperBound);
250 
251         if (!isSupportConnected()) {
252             return searchPlateau(complement, lowerBound, x);
253         }
254         return x;
255     }
256 
257     /**
258      * Create a finite lower bound. Assumes the current lower bound is negative infinity.
259      *
260      * @param p Cumulative probability.
261      * @param q Survival probability.
262      * @param complement Set to true to compute the inverse survival probability
263      * @param upperBound Current upper bound
264      * @param mu Mean
265      * @param sig Standard deviation
266      * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
267      * @return the finite lower bound
268      */
269     private double createFiniteLowerBound(final double p, final double q, boolean complement,
270         double upperBound, final double mu, final double sig, final boolean chebyshevApplies) {
271         double lowerBound;
272         if (chebyshevApplies) {
273             lowerBound = mu - sig * Math.sqrt(q / p);
274         } else {
275             lowerBound = Double.NEGATIVE_INFINITY;
276         }
277         // Bound may have been set as infinite
278         if (lowerBound == Double.NEGATIVE_INFINITY) {
279             lowerBound = Math.min(-1, upperBound);
280             if (complement) {
281                 while (survivalProbability(lowerBound) < q) {
282                     lowerBound *= 2;
283                 }
284             } else {
285                 while (cumulativeProbability(lowerBound) >= p) {
286                     lowerBound *= 2;
287                 }
288             }
289             // Ensure finite
290             lowerBound = Math.max(lowerBound, -Double.MAX_VALUE);
291         }
292         return lowerBound;
293     }
294 
295     /**
296      * Create a finite upper bound. Assumes the current upper bound is positive infinity.
297      *
298      * @param p Cumulative probability.
299      * @param q Survival probability.
300      * @param complement Set to true to compute the inverse survival probability
301      * @param lowerBound Current lower bound
302      * @param mu Mean
303      * @param sig Standard deviation
304      * @param chebyshevApplies True if the Chebyshev inequality applies (mean is finite and {@code sig > 0}}
305      * @return the finite lower bound
306      */
307     private double createFiniteUpperBound(final double p, final double q, boolean complement,
308         double lowerBound, final double mu, final double sig, final boolean chebyshevApplies) {
309         double upperBound;
310         if (chebyshevApplies) {
311             upperBound = mu + sig * Math.sqrt(p / q);
312         } else {
313             upperBound = Double.POSITIVE_INFINITY;
314         }
315         // Bound may have been set as infinite
316         if (upperBound == Double.POSITIVE_INFINITY) {
317             upperBound = Math.max(1, lowerBound);
318             if (complement) {
319                 while (survivalProbability(upperBound) >= q) {
320                     upperBound *= 2;
321                 }
322             } else {
323                 while (cumulativeProbability(upperBound) < p) {
324                     upperBound *= 2;
325                 }
326             }
327             // Ensure finite
328             upperBound = Math.min(upperBound, Double.MAX_VALUE);
329         }
330         return upperBound;
331     }
332 
333     /**
334      * Indicates whether the support is connected, i.e. whether all values between the
335      * lower and upper bound of the support are included in the support.
336      *
337      * <p>This method is used in the default implementation of the inverse cumulative and
338      * survival probability functions.
339      *
340      * <p>The default value is true which assumes the cdf and sf have no plateau regions
341      * where the same probability value is returned for a large range of x.
342      * Override this method if there are gaps in the support of the cdf and sf.
343      *
344      * <p>If false then the inverse will perform an additional step to ensure that the
345      * lower-bound of the interval on which the cdf is constant should be returned. This
346      * will search from the initial point x downwards if a smaller value also has the same
347      * cumulative (survival) probability.
348      *
349      * <p>Any plateau with a width in x smaller than the inverse absolute accuracy will
350      * not be searched.
351      *
352      * <p>Note: This method was public in commons math. It has been reduced to package private
353      * in commons statistics as it is an implementation detail.
354      *
355      * @return whether the support is connected.
356      * @see <a href="https://issues.apache.org/jira/browse/MATH-699">MATH-699</a>
357      */
358     boolean isSupportConnected() {
359         return true;
360     }
361 
362     /**
363      * Test the probability function for a plateau at the point x. If detected
364      * search the plateau for the lowest point y such that
365      * {@code inf{y in R | P(y) == P(x)}}.
366      *
367      * <p>This function is used when the distribution support is not connected
368      * to satisfy the inverse probability requirements of {@link ContinuousDistribution}
369      * on the returned value.
370      *
371      * @param complement Set to true to search the survival probability.
372      * @param lower Lower bound used to limit the search downwards.
373      * @param x Current value.
374      * @return the infimum y
375      */
376     private double searchPlateau(boolean complement, double lower, final double x) {
377         // Test for plateau. Lower the value x if the probability is the same.
378         // Ensure the step is robust to the solver accuracy being less
379         // than 1 ulp of x (e.g. dx=0 will infinite loop)
380         final double dx = Math.max(SOLVER_ABSOLUTE_ACCURACY, Math.ulp(x));
381         if (x - dx >= lower) {
382             final DoubleUnaryOperator fun = complement ?
383                 this::survivalProbability :
384                 this::cumulativeProbability;
385             final double px = fun.applyAsDouble(x);
386             if (fun.applyAsDouble(x - dx) == px) {
387                 double upperBound = x;
388                 double lowerBound = lower;
389                 // Bisection search
390                 // Require cdf(x) < px and sf(x) > px to move the lower bound
391                 // to the midpoint.
392                 final DoubleBinaryOperator cmp = complement ?
393                     (a, b) -> a > b ? -1 : 1 :
394                     (a, b) -> a < b ? -1 : 1;
395                 while (upperBound - lowerBound > dx) {
396                     final double midPoint = 0.5 * (lowerBound + upperBound);
397                     if (cmp.applyAsDouble(fun.applyAsDouble(midPoint), px) < 0) {
398                         lowerBound = midPoint;
399                     } else {
400                         upperBound = midPoint;
401                     }
402                 }
403                 return upperBound;
404             }
405         }
406         return x;
407     }
408 
409     /** {@inheritDoc} */
410     @Override
411     public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
412         // Inversion method distribution sampler.
413         return InverseTransformContinuousSampler.of(rng, this::inverseCumulativeProbability)::sample;
414     }
415 }