View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.distribution;
18  
19  import java.util.function.IntUnaryOperator;
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;
22  
23  /**
24   * Base class for integer-valued discrete distributions.  Default
25   * implementations are provided for some of the methods that do not vary
26   * from distribution to distribution.
27   *
28   * <p>This base class provides a default factory method for creating
29   * a {@linkplain DiscreteDistribution.Sampler sampler instance} that uses the
30   * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
31   * inversion method</a> for generating random samples that follow the
32   * distribution.
33   *
34   * <p>The class provides functionality to evaluate the probability in a range
35   * using either the cumulative probability or the survival probability.
36   * The survival probability is used if both arguments to
37   * {@link #probability(int, int)} are above the median.
38   * Child classes with a known median can override the default {@link #getMedian()}
39   * method.
40   */
41  abstract class AbstractDiscreteDistribution
42      implements DiscreteDistribution {
43      /** Marker value for no median.
44       * This is a long to be outside the value of any possible int valued median. */
45      private static final long NO_MEDIAN = Long.MIN_VALUE;
46  
47      /** Cached value of the median. */
48      private long median = NO_MEDIAN;
49  
50      /**
51       * Gets the median. This is used to determine if the arguments to the
52       * {@link #probability(int, int)} function are in the upper or lower domain.
53       *
54       * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
55       * with a value of 0.5.
56       *
57       * @return the median
58       */
59      int getMedian() {
60          long m = median;
61          if (m == NO_MEDIAN) {
62              median = m = inverseCumulativeProbability(0.5);
63          }
64          return (int) m;
65      }
66  
67      /** {@inheritDoc} */
68      @Override
69      public double probability(int x0,
70                                int x1) {
71          if (x0 > x1) {
72              throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
73          }
74          // As per the default interface method handle special cases:
75          // x0     = x1 : return 0
76          // x0 + 1 = x1 : return probability(x1)
77          // Long addition avoids overflow
78          if (x0 + 1L >= x1) {
79              return x0 == x1 ? 0.0 : probability(x1);
80          }
81  
82          // Use the survival probability when in the upper domain [3]:
83          //
84          //  lower          median         upper
85          //    |              |              |
86          // 1.     |------|
87          //        x0     x1
88          // 2.         |----------|
89          //            x0         x1
90          // 3.                  |--------|
91          //                     x0       x1
92  
93          final double m = getMedian();
94          if (x0 >= m) {
95              return survivalProbability(x0) - survivalProbability(x1);
96          }
97          return cumulativeProbability(x1) - cumulativeProbability(x0);
98      }
99  
100     /**
101      * {@inheritDoc}
102      *
103      * <p>The default implementation returns:
104      * <ul>
105      * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
106      * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
107      * <li>the result of a binary search between the lower and upper bound using
108      *     {@link #cumulativeProbability(int) cumulativeProbability(x)}.
109      *     The bounds may be bracketed for efficiency.</li>
110      * </ul>
111      *
112      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
113      */
114     @Override
115     public int inverseCumulativeProbability(double p) {
116         ArgumentUtils.checkProbability(p);
117         return inverseProbability(p, 1 - p, false);
118     }
119 
120     /**
121      * {@inheritDoc}
122      *
123      * <p>The default implementation returns:
124      * <ul>
125      * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
126      * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
127      * <li>the result of a binary search between the lower and upper bound using
128      *     {@link #survivalProbability(int) survivalProbability(x)}.
129      *     The bounds may be bracketed for efficiency.</li>
130      * </ul>
131      *
132      * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
133      */
134     @Override
135     public int inverseSurvivalProbability(double p) {
136         ArgumentUtils.checkProbability(p);
137         return inverseProbability(1 - p, p, true);
138     }
139 
140     /**
141      * Implementation for the inverse cumulative or survival probability.
142      *
143      * @param p Cumulative probability.
144      * @param q Survival probability.
145      * @param complement Set to true to compute the inverse survival probability
146      * @return the value
147      */
148     private int inverseProbability(double p, double q, boolean complement) {
149 
150         int lower = getSupportLowerBound();
151         if (p == 0) {
152             return lower;
153         }
154         int upper = getSupportUpperBound();
155         if (q == 0) {
156             return upper;
157         }
158 
159         // The binary search sets the upper value to the mid-point
160         // based on fun(x) >= 0. The upper value is returned.
161         //
162         // Create a function to search for x where the upper bound can be
163         // lowered if:
164         // cdf(x) >= p
165         // sf(x)  <= q
166         final IntUnaryOperator fun = complement ?
167             x -> Double.compare(q, survivalProbability(x)) :
168             x -> Double.compare(cumulativeProbability(x), p);
169 
170         if (lower == Integer.MIN_VALUE) {
171             if (fun.applyAsInt(lower) >= 0) {
172                 return lower;
173             }
174         } else {
175             // this ensures:
176             // cumulativeProbability(lower) < p
177             // survivalProbability(lower) > q
178             // which is important for the solving step
179             lower -= 1;
180         }
181 
182         // use the one-sided Chebyshev inequality to narrow the bracket
183         // cf. AbstractContinuousDistribution.inverseCumulativeProbability(double)
184         final double mu = getMean();
185         final double sig = Math.sqrt(getVariance());
186         final boolean chebyshevApplies = Double.isFinite(mu) &&
187                                          ArgumentUtils.isFiniteStrictlyPositive(sig);
188 
189         if (chebyshevApplies) {
190             double tmp = mu - sig * Math.sqrt(q / p);
191             if (tmp > lower) {
192                 lower = ((int) Math.ceil(tmp)) - 1;
193             }
194             tmp = mu + sig * Math.sqrt(p / q);
195             if (tmp < upper) {
196                 upper = ((int) Math.ceil(tmp)) - 1;
197             }
198         }
199 
200         return solveInverseProbability(fun, lower, upper);
201     }
202 
203     /**
204      * This is a utility function used by {@link
205      * #inverseProbability(double, double, boolean)}. It assumes
206      * that the inverse probability lies in the bracket {@code
207      * (lower, upper]}. The implementation does simple bisection to find the
208      * smallest {@code x} such that {@code fun(x) >= 0}.
209      *
210      * @param fun Probability function.
211      * @param lowerBound Value satisfying {@code fun(lower) < 0}.
212      * @param upperBound Value satisfying {@code fun(upper) >= 0}.
213      * @return the smallest x
214      */
215     private static int solveInverseProbability(IntUnaryOperator fun,
216                                                int lowerBound,
217                                                int upperBound) {
218         // Use long to prevent overflow during computation of the middle
219         long lower = lowerBound;
220         long upper = upperBound;
221         while (lower + 1 < upper) {
222             // Note: Cannot replace division by 2 with a right shift because
223             // (lower + upper) can be negative.
224             final long middle = (lower + upper) / 2;
225             final int pm = fun.applyAsInt((int) middle);
226             if (pm < 0) {
227                 lower = middle;
228             } else {
229                 upper = middle;
230             }
231         }
232         return (int) upper;
233     }
234 
235     /** {@inheritDoc} */
236     @Override
237     public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
238         // Inversion method distribution sampler.
239         return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
240     }
241 }