/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.io.Serializable;

import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

/**
 * Base class for integer-valued discrete distributions.  Default
 * implementations are provided for some of the methods that do not vary
 * from distribution to distribution.
 * <p>
 * Instances built with the deprecated no-argument constructor have a
 * {@code null} {@link #random} generator; the default implementations in
 * this class then fall back to the deprecated {@link #randomData} source
 * wherever a source of randomness is needed.</p>
 */
public abstract class AbstractIntegerDistribution implements IntegerDistribution, Serializable {

    /** Serializable version identifier */
    private static final long serialVersionUID = -1146319659338487221L;

    /**
     * RandomData instance used to generate samples from the distribution.
     * @deprecated As of 3.1, to be removed in 4.0. Please use the
     * {@link #random} instance variable instead.
     */
    @Deprecated
    protected final org.apache.commons.math3.random.RandomDataImpl randomData =
        new org.apache.commons.math3.random.RandomDataImpl();

    /**
     * RNG instance used to generate samples from the distribution.
     * This field is {@code null} when the deprecated no-argument
     * constructor was used.
     * @since 3.1
     */
    protected final RandomGenerator random;

    /**
     * @deprecated As of 3.1, to be removed in 4.0. Please use
     * {@link #AbstractIntegerDistribution(RandomGenerator)} instead.
     */
    @Deprecated
    protected AbstractIntegerDistribution() {
        // Legacy users are only allowed to access the deprecated "randomData".
        // New users are forbidden to use this constructor.
        random = null;
    }

    /**
     * @param rng Random number generator.
     * @since 3.1
     */
    protected AbstractIntegerDistribution(RandomGenerator rng) {
        random = rng;
    }

    /**
     * {@inheritDoc}
     *
     * The default implementation uses the identity
     * <p>{@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)}</p>
     */
    public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
        if (x1 < x0) {
            throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT,
                    x0, x1, true);
        }
        return cumulativeProbability(x1) - cumulativeProbability(x0);
    }

    /**
     * {@inheritDoc}
     *
     * The default implementation returns
     * <ul>
     * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li>
     * <li>{@link #getSupportUpperBound()} for {@code p = 1}, and</li>
     * <li>{@link #solveInverseCumulativeProbability(double, int, int)} for
     *     {@code 0 < p < 1}.</li>
     * </ul>
     */
    public int inverseCumulativeProbability(final double p) throws OutOfRangeException {
        if (p < 0.0 || p > 1.0) {
            throw new OutOfRangeException(p, 0, 1);
        }

        int lower = getSupportLowerBound();
        if (p == 0.0) {
            return lower;
        }
        if (lower == Integer.MIN_VALUE) {
            // "lower - 1" would overflow here, so the lower end of the
            // bracket is tested directly instead of being widened.
            if (checkedCumulativeProbability(lower) >= p) {
                return lower;
            }
        } else {
            lower -= 1; // this ensures cumulativeProbability(lower) < p, which
                        // is important for the solving step
        }

        int upper = getSupportUpperBound();
        if (p == 1.0) {
            return upper;
        }

        // use the one-sided Chebyshev inequality to narrow the bracket
        // cf. AbstractRealDistribution.inverseCumulativeProbability(double)
        final double mu = getNumericalMean();
        final double sigma = FastMath.sqrt(getNumericalVariance());
        // the bound is only usable with a finite mean and a finite, non-zero
        // standard deviation
        final boolean chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) ||
                Double.isInfinite(sigma) || Double.isNaN(sigma) || sigma == 0.0);
        if (chebyshevApplies) {
            double k = FastMath.sqrt((1.0 - p) / p);
            double tmp = mu - k * sigma;
            if (tmp > lower) {
                lower = ((int) FastMath.ceil(tmp)) - 1;
            }
            k = 1.0 / k;
            tmp = mu + k * sigma;
            if (tmp < upper) {
                upper = ((int) FastMath.ceil(tmp)) - 1;
            }
        }

        return solveInverseCumulativeProbability(p, lower, upper);
    }

    /**
     * This is a utility function used by {@link
     * #inverseCumulativeProbability(double)}. It assumes {@code 0 < p < 1} and
     * that the inverse cumulative probability lies in the bracket {@code
     * (lower, upper]}. The implementation does simple bisection to find the
     * smallest {@code p}-quantile <code>inf{x in Z | P(X<=x) >= p}</code>.
     *
     * @param p the cumulative probability
     * @param lower a value satisfying {@code cumulativeProbability(lower) < p}
     * @param upper a value satisfying {@code p <= cumulativeProbability(upper)}
     * @return the smallest {@code p}-quantile of this distribution
     */
    protected int solveInverseCumulativeProbability(final double p, int lower, int upper) {
        while (lower + 1 < upper) {
            int xm = (lower + upper) / 2;
            if (xm < lower || xm > upper) {
                /*
                 * Overflow.
                 * There will never be an overflow in both calculation methods
                 * for xm at the same time
                 */
                xm = lower + (upper - lower) / 2;
            }

            double pm = checkedCumulativeProbability(xm);
            if (pm >= p) {
                upper = xm;
            } else {
                lower = xm;
            }
        }
        return upper;
    }

    /**
     * {@inheritDoc}
     *
     * The deprecated {@link #randomData} source is always reseeded; the
     * {@link #random} generator is reseeded only when it is not {@code null}.
     */
    public void reseedRandomGenerator(long seed) {
        // "random" is null for instances created through the deprecated
        // no-argument constructor; skip it so that those instances can still
        // reseed "randomData" instead of failing with a NullPointerException.
        if (random != null) {
            random.setSeed(seed);
        }
        randomData.reSeed(seed);
    }

    /**
     * {@inheritDoc}
     *
     * The default implementation uses the
     * <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling">
     * inversion method</a>. The uniform deviate is drawn from {@link #random}
     * or, when that generator is {@code null}, from the deprecated
     * {@link #randomData} instance.
     */
    public int sample() {
        // Legacy instances (deprecated constructor) have no "random" generator.
        final double u = (random != null) ?
            random.nextDouble() :
            randomData.nextUniform(0.0, 1.0);
        return inverseCumulativeProbability(u);
    }

    /**
     * {@inheritDoc}
     *
     * The default implementation generates the sample by calling
     * {@link #sample()} in a loop.
     */
    public int[] sample(int sampleSize) {
        if (sampleSize <= 0) {
            throw new NotStrictlyPositiveException(
                    LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
        }
        int[] out = new int[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            out[i] = sample();
        }
        return out;
    }

    /**
     * Computes the cumulative probability function and checks for {@code NaN}
     * values returned. Throws {@code MathInternalError} if the value is
     * {@code NaN}. Any exception encountered while evaluating the cumulative
     * probability function propagates to the caller unchanged.
     *
     * @param argument input value
     * @return the cumulative probability
     * @throws MathInternalError if the cumulative probability is {@code NaN}
     */
    private double checkedCumulativeProbability(int argument)
        throws MathInternalError {
        final double result = cumulativeProbability(argument);
        if (Double.isNaN(result)) {
            throw new MathInternalError(LocalizedFormats
                    .DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN, argument);
        }
        return result;
    }

    /**
     * For a random variable {@code X} whose values are distributed according to
     * this distribution, this method returns {@code log(P(X = x))}, where
     * {@code log} is the natural logarithm. In other words, this method
     * represents the logarithm of the probability mass function (PMF) for the
     * distribution. Note that due to the floating point precision and
     * under/overflow issues, this method will for some distributions be more
     * precise and faster than computing the logarithm of
     * {@link #probability(int)}.
     * <p>
     * The default implementation simply computes the logarithm of {@code probability(x)}.</p>
     *
     * @param x the point at which the PMF is evaluated
     * @return the logarithm of the value of the probability mass function at {@code x}
     */
    public double logProbability(int x) {
        return FastMath.log(probability(x));
    }
}