/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.statistics.distribution;

import java.util.function.IntUnaryOperator;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;

/**
 * Skeleton implementation shared by the integer-valued discrete distributions.
 * It supplies the methods whose logic is the same for every distribution.
 *
 * <p>A default factory method is included that creates a
 * {@linkplain DiscreteDistribution.Sampler sampler instance} based on the
 * <a href="https://en.wikipedia.org/wiki/Inverse_transform_sampling">
 * inversion method</a>, so that generated samples follow the distribution.
 *
 * <p>The probability of a range is evaluated with either the cumulative
 * probability or the survival probability. The survival probability is chosen
 * when the lower argument of {@link #probability(int, int)} is at or above the
 * median (and hence the upper argument is too).
 * Subclasses that know their median can override the default {@link #getMedian()}
 * method.
 */
abstract class AbstractDiscreteDistribution
    implements DiscreteDistribution {
    /**
     * Sentinel stored in the cache while the median has not been computed.
     * A {@code long} is used so the sentinel cannot collide with any {@code int} median.
     */
    private static final long MEDIAN_NOT_SET = Long.MIN_VALUE;

    /** Lazily computed median; holds {@link #MEDIAN_NOT_SET} until first use. */
    private long cachedMedian = MEDIAN_NOT_SET;

    /**
     * Returns the median of the distribution. The value decides whether the
     * arguments passed to {@link #probability(int, int)} lie in the upper or
     * the lower part of the domain.
     *
     * <p>By default the median is obtained from
     * {@link #inverseCumulativeProbability(double)} evaluated at 0.5; the
     * result is stored and reused on later calls.
     *
     * @return the median
     */
    int getMedian() {
        // Read the field once and work on the local copy.
        final long cached = cachedMedian;
        if (cached != MEDIAN_NOT_SET) {
            return (int) cached;
        }
        final int computed = inverseCumulativeProbability(0.5);
        cachedMedian = computed;
        return computed;
    }

    /** {@inheritDoc} */
    @Override
    public double probability(int x0,
                              int x1) {
        if (x0 > x1) {
            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
        }
        // Special cases handled in the same way as the default interface method:
        // x0     = x1 : return 0
        // x0 + 1 = x1 : return probability(x1)
        if (x0 == x1) {
            return 0.0;
        }
        // The addition is performed as a long so that it cannot overflow.
        if (x0 + 1L == x1) {
            return probability(x1);
        }

        // The survival probability is used for a range in the upper domain [3]:
        //
        //  lower          median         upper
        //    |              |              |
        // 1.     |------|
        //        x0     x1
        // 2.         |----------|
        //            x0         x1
        // 3.                  |--------|
        //                     x0       x1

        final int medianValue = getMedian();
        return x0 >= medianValue ?
            survivalProbability(x0) - survivalProbability(x1) :
            cumulativeProbability(x1) - cumulativeProbability(x0);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The default implementation returns:
     * <ul>
     * <li>{@link #getSupportLowerBound()} when {@code p = 0},</li>
     * <li>{@link #getSupportUpperBound()} when {@code p = 1}, or</li>
     * <li>the outcome of a binary search between the lower and upper bound that
     *     evaluates {@link #cumulativeProbability(int) cumulativeProbability(x)}.
     *     The search bracket may be narrowed first for efficiency.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
     */
    @Override
    public int inverseCumulativeProbability(double p) {
        ArgumentUtils.checkProbability(p);
        final double q = 1 - p;
        return inverseProbability(p, q, false);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The default implementation returns:
     * <ul>
     * <li>{@link #getSupportLowerBound()} when {@code p = 1},</li>
     * <li>{@link #getSupportUpperBound()} when {@code p = 0}, or</li>
     * <li>the outcome of a binary search between the lower and upper bound that
     *     evaluates {@link #survivalProbability(int) survivalProbability(x)}.
     *     The search bracket may be narrowed first for efficiency.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
     */
    @Override
    public int inverseSurvivalProbability(double p) {
        ArgumentUtils.checkProbability(p);
        final double cumulative = 1 - p;
        return inverseProbability(cumulative, p, true);
    }

    /**
     * Common routine behind the inverse cumulative and the inverse survival probability.
     *
     * @param p Cumulative probability.
     * @param q Survival probability.
     * @param complement Set to true to compute the inverse survival probability.
     * @return the value
     */
    private int inverseProbability(double p, double q, boolean complement) {
        // Each bound is requested only when it is needed.
        int lower = getSupportLowerBound();
        if (p == 0) {
            return lower;
        }
        int upper = getSupportUpperBound();
        if (q == 0) {
            return upper;
        }

        // The bisection moves its upper value down to the mid-point whenever
        // fun(x) >= 0 and finally returns that upper value.
        //
        // The function is therefore built so that it is non-negative when:
        // cdf(x) >= p
        // sf(x)  <= q
        final IntUnaryOperator fun;
        if (complement) {
            fun = x -> Double.compare(q, survivalProbability(x));
        } else {
            fun = x -> Double.compare(cumulativeProbability(x), p);
        }

        if (lower != Integer.MIN_VALUE) {
            // Stepping one below the support guarantees:
            // cumulativeProbability(lower) < p
            // survivalProbability(lower)   > q
            // which the solving step depends upon.
            lower--;
        } else if (fun.applyAsInt(lower) >= 0) {
            // Cannot step below the smallest int; it may already be the answer.
            return lower;
        }

        // Narrow the bracket with the one-sided Chebyshev inequality
        // cf. AbstractContinuousDistribution.inverseCumulativeProbability(double)
        final double mean = getMean();
        final double sd = Math.sqrt(getVariance());

        if (Double.isFinite(mean) && ArgumentUtils.isFiniteStrictlyPositive(sd)) {
            final double lowerLimit = mean - sd * Math.sqrt(q / p);
            if (lowerLimit > lower) {
                lower = ((int) Math.ceil(lowerLimit)) - 1;
            }
            final double upperLimit = mean + sd * Math.sqrt(p / q);
            if (upperLimit < upper) {
                upper = ((int) Math.ceil(upperLimit)) - 1;
            }
        }

        return solveInverseProbability(fun, lower, upper);
    }

    /**
     * Helper for {@link #inverseProbability(double, double, boolean)}.
     * The inverse probability is assumed to lie within the bracket
     * {@code (lowerBound, upperBound]}. Plain bisection locates the
     * smallest {@code x} for which {@code fun(x) >= 0}.
     *
     * @param fun Probability function.
     * @param lowerBound Value satisfying {@code fun(lowerBound) < 0}.
     * @param upperBound Value satisfying {@code fun(upperBound) >= 0}.
     * @return the smallest x
     */
    private static int solveInverseProbability(IntUnaryOperator fun,
                                               int lowerBound,
                                               int upperBound) {
        // Widen to long so the mid-point computation cannot overflow.
        long lo = lowerBound;
        long hi = upperBound;
        while (hi - lo > 1) {
            // Note: The division by 2 must not become a right shift because
            // (lo + hi) can be negative.
            final long mid = (lo + hi) / 2;
            if (fun.applyAsInt((int) mid) < 0) {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        return (int) hi;
    }

    /** {@inheritDoc} */
    @Override
    public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
        // Sampler for the distribution based on the inversion method.
        return InverseTransformDiscreteSampler.of(rng, this::inverseCumulativeProbability)::sample;
    }
}