/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.distribution;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math4.legacy.exception.MathArithmeticException;
import org.apache.commons.math4.legacy.exception.NotANumberException;
import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
import org.apache.commons.math4.legacy.exception.NotPositiveException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.core.Pair;

/**
 * <p>A generic implementation of a
 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
 * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be
 * non-negative, but zero values are allowed and their sum does not have to equal one.  Constructors
 * will normalize input probabilities to make them sum to one.</p>
 *
 * <p>The list of &lt;value, probability&gt; pairs does not, strictly speaking, have to be a function
 * and it can contain null values.  The pmf created by the constructor will combine probabilities of
 * equal values and will treat null values as equal.  For example, if the list of pairs
 * &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;, &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt;
 * is provided to the constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog"
 * and 0.2 to "pig".</p>
 *
 * @param <T> type of the elements in the sample space.
 * @since 3.2
 */
public class EnumeratedDistribution<T> {
    /**
     * List of random variable values, in constructor input order
     * (duplicates and nulls are retained).
     */
    private final List<T> singletons;
    /**
     * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
     * probability[i] is the probability that a random variable following this distribution takes
     * the value singletons[i].
     */
    private final double[] probabilities;

    /**
     * Create an enumerated distribution using the given probability mass
     * function enumeration. The input probabilities are validated and then
     * normalized so that they sum to one.
     *
     * @param pmf probability mass function enumerated as a list of
     * {@code <T, probability>} pairs.
     * @throws NotPositiveException if any of the probabilities are negative.
     * @throws NotFiniteNumberException if any of the probabilities are infinite.
     * @throws NotANumberException if any of the probabilities are NaN.
     * @throws MathArithmeticException if all of the probabilities are 0.
     */
    public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
        throws NotPositiveException,
               MathArithmeticException,
               NotFiniteNumberException,
               NotANumberException {
        singletons = new ArrayList<>(pmf.size());
        final double[] probs = new double[pmf.size()];

        int count = 0;
        for (Pair<T, Double> sample : pmf) {
            singletons.add(sample.getKey());
            final double p = sample.getValue();
            // The sign check comes first, so negative infinity is reported
            // as a negative value rather than as a non-finite one.
            if (p < 0) {
                throw new NotPositiveException(p);
            }
            if (Double.isInfinite(p)) {
                throw new NotFiniteNumberException(p);
            }
            if (Double.isNaN(p)) {
                throw new NotANumberException();
            }
            probs[count++] = p;
        }

        probabilities = MathArrays.normalizeArray(probs, 1.0);
    }

    /**
     * <p>For a random variable {@code X} whose values are distributed according to
     * this distribution, this method returns {@code P(X = x)}. In other words,
     * this method represents the probability mass function (PMF) for the
     * distribution.</p>
     *
     * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
     * or both are null, then {@code probability(x1) = probability(x2)}.</p>
     *
     * @param x the point at which the PMF is evaluated
     * @return the value of the probability mass function at {@code x}
     */
    double probability(final T x) {
        double probability = 0;

        // Duplicate entries are allowed, so the whole list is scanned and the
        // mass of every entry equal to x (or null, when x is null) is summed.
        for (int i = 0; i < probabilities.length; i++) {
            if ((x == null && singletons.get(i) == null) ||
                (x != null && x.equals(singletons.get(i)))) {
                probability += probabilities[i];
            }
        }

        return probability;
    }

    /**
     * <p>Return the probability mass function as a list of &lt;value, probability&gt; pairs.</p>
     *
     * <p>Note that if duplicate and / or null values were provided to the constructor
     * when creating this EnumeratedDistribution, the returned list will contain these
     * values.  If duplicate values exist, what is returned will not represent
     * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
     *
     * @return the probability mass function, as a newly allocated list holding
     * the normalized probabilities in constructor input order.
     */
    public List<Pair<T, Double>> getPmf() {
        final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);

        for (int i = 0; i < probabilities.length; i++) {
            samples.add(new Pair<>(singletons.get(i), probabilities[i]));
        }

        return samples;
    }

    /**
     * Creates a {@link Sampler}.
     *
     * @param rng Random number generator.
     * @return a new sampler instance.
     */
    public Sampler createSampler(final UniformRandomProvider rng) {
        return new Sampler(rng);
    }

    /**
     * Sampler functionality.
     *
     * <ul>
     *  <li>
     *   The cumulative probability distribution is created (and sampled from)
     *   using the input order of the {@link EnumeratedDistribution#EnumeratedDistribution(List)
     *   constructor arguments}: A different input order will create a different
     *   sequence of samples.
     *   The samples will only be reproducible with the same RNG starting from
     *   the same RNG state and the same input order to constructor.
     *  </li>
     *  <li>
     *   The minimum supported probability is 2<sup>-53</sup>.
     *  </li>
     * </ul>
     */
    public class Sampler {
        /** Underlying sampler. */
        private final DiscreteProbabilityCollectionSampler<T> sampler;

        /**
         * @param rng Random number generator.
         */
        Sampler(UniformRandomProvider rng) {
            sampler = new DiscreteProbabilityCollectionSampler<>(rng, singletons, probabilities);
        }

        /**
         * Generates a random value sampled from this distribution.
         *
         * @return a random value.
         */
        public T sample() {
            return sampler.sample();
        }

        /**
         * Generates a random sample from the distribution.
         *
         * @param sampleSize the number of random values to generate.
         * @return an array representing the random sample.
         * @throws NotStrictlyPositiveException if {@code sampleSize} is not
         * positive.
         */
        public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
            if (sampleSize <= 0) {
                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
                        sampleSize);
            }

            final Object[] out = new Object[sampleSize];

            for (int i = 0; i < sampleSize; i++) {
                out[i] = sample();
            }

            return out;
        }

        /**
         * Generates a random sample from the distribution.
         * <p>
         * If the requested samples fit in the specified array, it is returned
         * therein (only the first {@code sampleSize} elements are written).
         * Otherwise, a new array is allocated with the runtime type of
         * the specified array and a length of {@code sampleSize}.
         *
         * @param sampleSize the number of random values to generate.
         * @param array the array to populate.
         * @return an array representing the random sample.
         * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
         * @throws NullArgumentException if {@code array} is null
         */
        public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
            if (sampleSize <= 0) {
                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
            }

            if (array == null) {
                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
            }

            T[] out;
            if (array.length < sampleSize) {
                @SuppressWarnings("unchecked") // safe as both are of type T
                final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
                out = unchecked;
            } else {
                out = array;
            }

            for (int i = 0; i < sampleSize; i++) {
                out[i] = sample();
            }

            return out;
        }
    }
}