001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.math4.distribution; 018 019import java.io.Serializable; 020import java.lang.reflect.Array; 021import java.util.ArrayList; 022import java.util.List; 023 024import org.apache.commons.math4.exception.MathArithmeticException; 025import org.apache.commons.math4.exception.NotANumberException; 026import org.apache.commons.math4.exception.NotFiniteNumberException; 027import org.apache.commons.math4.exception.NotPositiveException; 028import org.apache.commons.math4.exception.NotStrictlyPositiveException; 029import org.apache.commons.math4.exception.NullArgumentException; 030import org.apache.commons.math4.exception.util.LocalizedFormats; 031import org.apache.commons.rng.UniformRandomProvider; 032import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler; 033import org.apache.commons.math4.util.MathArrays; 034import org.apache.commons.math4.util.Pair; 035 036/** 037 * <p>A generic implementation of a 038 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 039 * discrete probability distribution (Wikipedia)</a> over a finite sample space, 040 * based on an enumerated list of <value, probability> pairs. Input probabilities must all be non-negative, 041 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input 042 * probabilities to make them sum to one.</p> 043 * 044 * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can 045 * contain null values. The pmf created by the constructor will combine probabilities of equal values and 046 * will treat null values as equal. For example, if the list of pairs <"dog", 0.2>, <null, 0.1>, 047 * <"pig", 0.2>, <"dog", 0.1>, <null, 0.4> is provided to the constructor, the resulting 048 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to pig.</p> 049 * 050 * @param <T> type of the elements in the sample space. 051 * @since 3.2 052 */ 053public class EnumeratedDistribution<T> implements Serializable { 054 /** Serializable UID. */ 055 private static final long serialVersionUID = 20160319L; 056 /** 057 * List of random variable values. 058 */ 059 private final List<T> singletons; 060 /** 061 * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, 062 * probability[i] is the probability that a random variable following this distribution takes 063 * the value singletons[i]. 064 */ 065 private final double[] probabilities; 066 /** 067 * Cumulative probabilities, cached to speed up sampling. 068 */ 069 private final double[] cumulativeProbabilities; 070 071 /** 072 * Create an enumerated distribution using the given random number generator 073 * and probability mass function enumeration. 074 * 075 * @param pmf probability mass function enumerated as a list of 076 * {@code <T, probability>} pairs. 077 * @throws NotPositiveException if any of the probabilities are negative. 078 * @throws NotFiniteNumberException if any of the probabilities are infinite. 079 * @throws NotANumberException if any of the probabilities are NaN. 080 * @throws MathArithmeticException all of the probabilities are 0. 081 */ 082 public EnumeratedDistribution(final List<Pair<T, Double>> pmf) 083 throws NotPositiveException, 084 MathArithmeticException, 085 NotFiniteNumberException, 086 NotANumberException { 087 singletons = new ArrayList<>(pmf.size()); 088 final double[] probs = new double[pmf.size()]; 089 int count = 0; 090 for (Pair<T, Double> sample : pmf) { 091 singletons.add(sample.getKey()); 092 final double p = sample.getValue(); 093 if (p < 0) { 094 throw new NotPositiveException(sample.getValue()); 095 } 096 if (Double.isInfinite(p)) { 097 throw new NotFiniteNumberException(p); 098 } 099 if (Double.isNaN(p)) { 100 throw new NotANumberException(); 101 } 102 probs[count++] = p; 103 } 104 105 probabilities = MathArrays.normalizeArray(probs, 1.0); 106 107 cumulativeProbabilities = new double[probabilities.length]; 108 double sum = 0; 109 for (int i = 0; i < probabilities.length; i++) { 110 sum += probabilities[i]; 111 cumulativeProbabilities[i] = sum; 112 } 113 } 114 115 /** 116 * <p>For a random variable {@code X} whose values are distributed according to 117 * this distribution, this method returns {@code P(X = x)}. In other words, 118 * this method represents the probability mass function (PMF) for the 119 * distribution.</p> 120 * 121 * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)}, 122 * or both are null, then {@code probability(x1) = probability(x2)}.</p> 123 * 124 * @param x the point at which the PMF is evaluated 125 * @return the value of the probability mass function at {@code x} 126 */ 127 double probability(final T x) { 128 double probability = 0; 129 130 for (int i = 0; i < probabilities.length; i++) { 131 if ((x == null && singletons.get(i) == null) || 132 (x != null && x.equals(singletons.get(i)))) { 133 probability += probabilities[i]; 134 } 135 } 136 137 return probability; 138 } 139 140 /** 141 * <p>Return the probability mass function as a list of <value, probability> pairs.</p> 142 * 143 * <p>Note that if duplicate and / or null values were provided to the constructor 144 * when creating this EnumeratedDistribution, the returned list will contain these 145 * values. If duplicates values exist, what is returned will not represent 146 * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p> 147 * 148 * @return the probability mass function. 149 */ 150 public List<Pair<T, Double>> getPmf() { 151 final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length); 152 153 for (int i = 0; i < probabilities.length; i++) { 154 samples.add(new Pair<>(singletons.get(i), probabilities[i])); 155 } 156 157 return samples; 158 } 159 160 /** 161 * Creates a {@link Sampler}. 162 * 163 * @param rng Random number generator. 164 * @return a new sampler instance. 165 */ 166 public Sampler createSampler(final UniformRandomProvider rng) { 167 return new Sampler(rng); 168 } 169 170 /** 171 * Sampler functionality. 172 */ 173 public class Sampler { 174 /** Underlying sampler. */ 175 private final DiscreteProbabilityCollectionSampler<T> sampler; 176 177 /** 178 * @param rng Random number generator. 179 */ 180 Sampler(UniformRandomProvider rng) { 181 sampler = new DiscreteProbabilityCollectionSampler<T>(rng, singletons, probabilities); 182 } 183 184 /** 185 * Generates a random value sampled from this distribution. 186 * 187 * @return a random value. 188 */ 189 public T sample() { 190 return sampler.sample(); 191 } 192 193 /** 194 * Generates a random sample from the distribution. 195 * 196 * @param sampleSize the number of random values to generate. 197 * @return an array representing the random sample. 198 * @throws NotStrictlyPositiveException if {@code sampleSize} is not 199 * positive. 200 */ 201 public Object[] sample(int sampleSize) throws NotStrictlyPositiveException { 202 if (sampleSize <= 0) { 203 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, 204 sampleSize); 205 } 206 207 final Object[] out = new Object[sampleSize]; 208 209 for (int i = 0; i < sampleSize; i++) { 210 out[i] = sample(); 211 } 212 213 return out; 214 } 215 216 /** 217 * Generates a random sample from the distribution. 218 * <p> 219 * If the requested samples fit in the specified array, it is returned 220 * therein. Otherwise, a new array is allocated with the runtime type of 221 * the specified array and the size of this collection. 222 * 223 * @param sampleSize the number of random values to generate. 224 * @param array the array to populate. 225 * @return an array representing the random sample. 226 * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive. 227 * @throws NullArgumentException if {@code array} is null 228 */ 229 public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException { 230 if (sampleSize <= 0) { 231 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); 232 } 233 234 if (array == null) { 235 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); 236 } 237 238 T[] out; 239 if (array.length < sampleSize) { 240 @SuppressWarnings("unchecked") // safe as both are of type T 241 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize); 242 out = unchecked; 243 } else { 244 out = array; 245 } 246 247 for (int i = 0; i < sampleSize; i++) { 248 out[i] = sample(); 249 } 250 251 return out; 252 } 253 } 254}