001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math3.distribution; 018 019import java.io.Serializable; 020import java.lang.reflect.Array; 021import java.util.ArrayList; 022import java.util.Arrays; 023import java.util.List; 024 025import org.apache.commons.math3.exception.MathArithmeticException; 026import org.apache.commons.math3.exception.NotANumberException; 027import org.apache.commons.math3.exception.NotFiniteNumberException; 028import org.apache.commons.math3.exception.NotPositiveException; 029import org.apache.commons.math3.exception.NotStrictlyPositiveException; 030import org.apache.commons.math3.exception.NullArgumentException; 031import org.apache.commons.math3.exception.util.LocalizedFormats; 032import org.apache.commons.math3.random.RandomGenerator; 033import org.apache.commons.math3.random.Well19937c; 034import org.apache.commons.math3.util.MathArrays; 035import org.apache.commons.math3.util.Pair; 036 037/** 038 * <p>A generic implementation of a 039 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 040 * discrete probability distribution (Wikipedia)</a> over a finite sample space, 041 * based on an enumerated list of <value, probability> pairs. Input probabilities must all be non-negative, 042 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input 043 * probabilities to make them sum to one.</p> 044 * 045 * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can 046 * contain null values. The pmf created by the constructor will combine probabilities of equal values and 047 * will treat null values as equal. For example, if the list of pairs <"dog", 0.2>, <null, 0.1>, 048 * <"pig", 0.2>, <"dog", 0.1>, <null, 0.4> is provided to the constructor, the resulting 049 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to null.</p> 050 * 051 * @param <T> type of the elements in the sample space. 052 * @since 3.2 053 */ 054public class EnumeratedDistribution<T> implements Serializable { 055 056 /** Serializable UID. */ 057 private static final long serialVersionUID = 20123308L; 058 059 /** 060 * RNG instance used to generate samples from the distribution. 061 */ 062 protected final RandomGenerator random; 063 064 /** 065 * List of random variable values. 066 */ 067 private final List<T> singletons; 068 069 /** 070 * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, 071 * probability[i] is the probability that a random variable following this distribution takes 072 * the value singletons[i]. 073 */ 074 private final double[] probabilities; 075 076 /** 077 * Cumulative probabilities, cached to speed up sampling. 078 */ 079 private final double[] cumulativeProbabilities; 080 081 /** 082 * Create an enumerated distribution using the given probability mass function 083 * enumeration. 084 * <p> 085 * <b>Note:</b> this constructor will implicitly create an instance of 086 * {@link Well19937c} as random generator to be used for sampling only (see 087 * {@link #sample()} and {@link #sample(int)}). In case no sampling is 088 * needed for the created distribution, it is advised to pass {@code null} 089 * as random generator via the appropriate constructors to avoid the 090 * additional initialisation overhead. 091 * 092 * @param pmf probability mass function enumerated as a list of <T, probability> 093 * pairs. 094 * @throws NotPositiveException if any of the probabilities are negative. 095 * @throws NotFiniteNumberException if any of the probabilities are infinite. 096 * @throws NotANumberException if any of the probabilities are NaN. 097 * @throws MathArithmeticException all of the probabilities are 0. 098 */ 099 public EnumeratedDistribution(final List<Pair<T, Double>> pmf) 100 throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException { 101 this(new Well19937c(), pmf); 102 } 103 104 /** 105 * Create an enumerated distribution using the given random number generator 106 * and probability mass function enumeration. 107 * 108 * @param rng random number generator. 109 * @param pmf probability mass function enumerated as a list of <T, probability> 110 * pairs. 111 * @throws NotPositiveException if any of the probabilities are negative. 112 * @throws NotFiniteNumberException if any of the probabilities are infinite. 113 * @throws NotANumberException if any of the probabilities are NaN. 114 * @throws MathArithmeticException all of the probabilities are 0. 115 */ 116 public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf) 117 throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException { 118 random = rng; 119 120 singletons = new ArrayList<T>(pmf.size()); 121 final double[] probs = new double[pmf.size()]; 122 123 for (int i = 0; i < pmf.size(); i++) { 124 final Pair<T, Double> sample = pmf.get(i); 125 singletons.add(sample.getKey()); 126 final double p = sample.getValue(); 127 if (p < 0) { 128 throw new NotPositiveException(sample.getValue()); 129 } 130 if (Double.isInfinite(p)) { 131 throw new NotFiniteNumberException(p); 132 } 133 if (Double.isNaN(p)) { 134 throw new NotANumberException(); 135 } 136 probs[i] = p; 137 } 138 139 probabilities = MathArrays.normalizeArray(probs, 1.0); 140 141 cumulativeProbabilities = new double[probabilities.length]; 142 double sum = 0; 143 for (int i = 0; i < probabilities.length; i++) { 144 sum += probabilities[i]; 145 cumulativeProbabilities[i] = sum; 146 } 147 } 148 149 /** 150 * Reseed the random generator used to generate samples. 151 * 152 * @param seed the new seed 153 */ 154 public void reseedRandomGenerator(long seed) { 155 random.setSeed(seed); 156 } 157 158 /** 159 * <p>For a random variable {@code X} whose values are distributed according to 160 * this distribution, this method returns {@code P(X = x)}. In other words, 161 * this method represents the probability mass function (PMF) for the 162 * distribution.</p> 163 * 164 * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)}, 165 * or both are null, then {@code probability(x1) = probability(x2)}.</p> 166 * 167 * @param x the point at which the PMF is evaluated 168 * @return the value of the probability mass function at {@code x} 169 */ 170 double probability(final T x) { 171 double probability = 0; 172 173 for (int i = 0; i < probabilities.length; i++) { 174 if ((x == null && singletons.get(i) == null) || 175 (x != null && x.equals(singletons.get(i)))) { 176 probability += probabilities[i]; 177 } 178 } 179 180 return probability; 181 } 182 183 /** 184 * <p>Return the probability mass function as a list of <value, probability> pairs.</p> 185 * 186 * <p>Note that if duplicate and / or null values were provided to the constructor 187 * when creating this EnumeratedDistribution, the returned list will contain these 188 * values. If duplicates values exist, what is returned will not represent 189 * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p> 190 * 191 * @return the probability mass function. 192 */ 193 public List<Pair<T, Double>> getPmf() { 194 final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length); 195 196 for (int i = 0; i < probabilities.length; i++) { 197 samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i])); 198 } 199 200 return samples; 201 } 202 203 /** 204 * Generate a random value sampled from this distribution. 205 * 206 * @return a random value. 207 */ 208 public T sample() { 209 final double randomValue = random.nextDouble(); 210 211 int index = Arrays.binarySearch(cumulativeProbabilities, randomValue); 212 if (index < 0) { 213 index = -index-1; 214 } 215 216 if (index >= 0 && 217 index < probabilities.length && 218 randomValue < cumulativeProbabilities[index]) { 219 return singletons.get(index); 220 } 221 222 /* This should never happen, but it ensures we will return a correct 223 * object in case there is some floating point inequality problem 224 * wrt the cumulative probabilities. */ 225 return singletons.get(singletons.size() - 1); 226 } 227 228 /** 229 * Generate a random sample from the distribution. 230 * 231 * @param sampleSize the number of random values to generate. 232 * @return an array representing the random sample. 233 * @throws NotStrictlyPositiveException if {@code sampleSize} is not 234 * positive. 235 */ 236 public Object[] sample(int sampleSize) throws NotStrictlyPositiveException { 237 if (sampleSize <= 0) { 238 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, 239 sampleSize); 240 } 241 242 final Object[] out = new Object[sampleSize]; 243 244 for (int i = 0; i < sampleSize; i++) { 245 out[i] = sample(); 246 } 247 248 return out; 249 250 } 251 252 /** 253 * Generate a random sample from the distribution. 254 * <p> 255 * If the requested samples fit in the specified array, it is returned 256 * therein. Otherwise, a new array is allocated with the runtime type of 257 * the specified array and the size of this collection. 258 * 259 * @param sampleSize the number of random values to generate. 260 * @param array the array to populate. 261 * @return an array representing the random sample. 262 * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive. 263 * @throws NullArgumentException if {@code array} is null 264 */ 265 public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException { 266 if (sampleSize <= 0) { 267 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); 268 } 269 270 if (array == null) { 271 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); 272 } 273 274 T[] out; 275 if (array.length < sampleSize) { 276 @SuppressWarnings("unchecked") // safe as both are of type T 277 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize); 278 out = unchecked; 279 } else { 280 out = array; 281 } 282 283 for (int i = 0; i < sampleSize; i++) { 284 out[i] = sample(); 285 } 286 287 return out; 288 289 } 290 291}