1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.statistics.distribution;
18
19 import java.util.function.IntUnaryOperator;
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;
22
23 /**
24 * Base class for integer-valued discrete distributions. Default
25 * implementations are provided for some of the methods that do not vary
26 * from distribution to distribution.
27 *
28 * <p>This base class provides a default factory method for creating
29 * a {@linkplain DiscreteDistribution.Sampler sampler instance} that uses the
30 * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
31 * inversion method</a> for generating random samples that follow the
32 * distribution.
33 *
34 * <p>The class provides functionality to evaluate the probability in a range
35 * using either the cumulative probability or the survival probability.
36 * The survival probability is used if both arguments to
37 * {@link #probability(int, int)} are above the median.
38 * Child classes with a known median can override the default {@link #getMedian()}
39 * method.
40 */
41 abstract class AbstractDiscreteDistribution
42 implements DiscreteDistribution {
43 /** Marker value for no median.
44 * This is a long to be outside the value of any possible int valued median. */
45 private static final long NO_MEDIAN = Long.MIN_VALUE;
46
47 /** Cached value of the median. */
48 private long median = NO_MEDIAN;
49
50 /**
51 * Gets the median. This is used to determine if the arguments to the
52 * {@link #probability(int, int)} function are in the upper or lower domain.
53 *
54 * <p>The default implementation calls {@link #inverseCumulativeProbability(double)}
55 * with a value of 0.5.
56 *
57 * @return the median
58 */
59 int getMedian() {
60 long m = median;
61 if (m == NO_MEDIAN) {
62 median = m = inverseCumulativeProbability(0.5);
63 }
64 return (int) m;
65 }
66
67 /** {@inheritDoc} */
68 @Override
69 public double probability(int x0,
70 int x1) {
71 if (x0 > x1) {
72 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
73 }
74 // As per the default interface method handle special cases:
75 // x0 = x1 : return 0
76 // x0 + 1 = x1 : return probability(x1)
77 // Long addition avoids overflow
78 if (x0 + 1L >= x1) {
79 return x0 == x1 ? 0.0 : probability(x1);
80 }
81
82 // Use the survival probability when in the upper domain [3]:
83 //
84 // lower median upper
85 // | | |
86 // 1. |------|
87 // x0 x1
88 // 2. |----------|
89 // x0 x1
90 // 3. |--------|
91 // x0 x1
92
93 final double m = getMedian();
94 if (x0 >= m) {
95 return survivalProbability(x0) - survivalProbability(x1);
96 }
97 return cumulativeProbability(x1) - cumulativeProbability(x0);
98 }
99
100 /**
101 * {@inheritDoc}
102 *
103 * <p>The default implementation returns:
104 * <ul>
105 * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
106 * <li>{@link #getSupportUpperBound()} for {@code p = 1}, or</li>
107 * <li>the result of a binary search between the lower and upper bound using
108 * {@link #cumulativeProbability(int) cumulativeProbability(x)}.
109 * The bounds may be bracketed for efficiency.</li>
110 * </ul>
111 *
112 * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
113 */
114 @Override
115 public int inverseCumulativeProbability(double p) {
116 ArgumentUtils.checkProbability(p);
117 return inverseProbability(p, 1 - p, false);
118 }
119
120 /**
121 * {@inheritDoc}
122 *
123 * <p>The default implementation returns:
124 * <ul>
125 * <li>{@link #getSupportLowerBound()} for {@code p = 1},</li>
126 * <li>{@link #getSupportUpperBound()} for {@code p = 0}, or</li>
127 * <li>the result of a binary search between the lower and upper bound using
128 * {@link #survivalProbability(int) survivalProbability(x)}.
129 * The bounds may be bracketed for efficiency.</li>
130 * </ul>
131 *
132 * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
133 */
134 @Override
135 public int inverseSurvivalProbability(double p) {
136 ArgumentUtils.checkProbability(p);
137 return inverseProbability(1 - p, p, true);
138 }
139
140 /**
141 * Implementation for the inverse cumulative or survival probability.
142 *
143 * @param p Cumulative probability.
144 * @param q Survival probability.
145 * @param complement Set to true to compute the inverse survival probability
146 * @return the value
147 */
148 private int inverseProbability(double p, double q, boolean complement) {
149
150 int lower = getSupportLowerBound();
151 if (p == 0) {
152 return lower;
153 }
154 int upper = getSupportUpperBound();
155 if (q == 0) {
156 return upper;
157 }
158
159 // The binary search sets the upper value to the mid-point
160 // based on fun(x) >= 0. The upper value is returned.
161 //
162 // Create a function to search for x where the upper bound can be
163 // lowered if:
164 // cdf(x) >= p
165 // sf(x) <= q
166 final IntUnaryOperator fun = complement ?
167 x -> Double.compare(q, survivalProbability(x)) :
168 x -> Double.compare(cumulativeProbability(x), p);
169
170 if (lower == Integer.MIN_VALUE) {
171 if (fun.applyAsInt(lower) >= 0) {
172 return lower;
173 }
174 } else {
175 // this ensures:
176 // cumulativeProbability(lower) < p
177 // survivalProbability(lower) > q
178 // which is important for the solving step
179 lower -= 1;
180 }
181
182 // use the one-sided Chebyshev inequality to narrow the bracket
183 // cf. AbstractContinuousDistribution.inverseCumulativeProbability(double)
184 final double mu = getMean();
185 final double sig = Math.sqrt(getVariance());
186 final boolean chebyshevApplies = Double.isFinite(mu) &&
187 ArgumentUtils.isFiniteStrictlyPositive(sig);
188
189 if (chebyshevApplies) {
190 double tmp = mu - sig * Math.sqrt(q / p);
191 if (tmp > lower) {
192 lower = ((int) Math.ceil(tmp)) - 1;
193 }
194 tmp = mu + sig * Math.sqrt(p / q);
195 if (tmp < upper) {
196 upper = ((int) Math.ceil(tmp)) - 1;
197 }
198 }
199
200 return solveInverseProbability(fun, lower, upper);
201 }
202
203 /**
204 * This is a utility function used by {@link
205 * #inverseProbability(double, double, boolean)}. It assumes
206 * that the inverse probability lies in the bracket {@code
207 * (lower, upper]}. The implementation does simple bisection to find the
208 * smallest {@code x} such that {@code fun(x) >= 0}.
209 *
210 * @param fun Probability function.
211 * @param lowerBound Value satisfying {@code fun(lower) < 0}.
212 * @param upperBound Value satisfying {@code fun(upper) >= 0}.
213 * @return the smallest x
214 */
215 private static int solveInverseProbability(IntUnaryOperator fun,
216 int lowerBound,
217 int upperBound) {
218 // Use long to prevent overflow during computation of the middle
219 long lower = lowerBound;
220 long upper = upperBound;
221 while (lower + 1 < upper) {
222 // Note: Cannot replace division by 2 with a right shift because
223 // (lower + upper) can be negative.
224 final long middle = (lower + upper) / 2;
225 final int pm = fun.applyAsInt((int) middle);
226 if (pm < 0) {
227 lower = middle;
228 } else {
229 upper = middle;
230 }
231 }
232 return (int) upper;
233 }
234
235 /** {@inheritDoc} */
236 @Override
237 public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
238 // Inversion method distribution sampler.
239 return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
240 }
241 }