1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math4.legacy.distribution;
18
19 import java.lang.reflect.Array;
20 import java.util.ArrayList;
21 import java.util.List;
22
23 import org.apache.commons.math4.legacy.exception.MathArithmeticException;
24 import org.apache.commons.math4.legacy.exception.NotANumberException;
25 import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
26 import org.apache.commons.math4.legacy.exception.NotPositiveException;
27 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
28 import org.apache.commons.math4.legacy.exception.NullArgumentException;
29 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
30 import org.apache.commons.rng.UniformRandomProvider;
31 import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
32 import org.apache.commons.math4.legacy.core.MathArrays;
33 import org.apache.commons.math4.legacy.core.Pair;
34
35 /**
36 * <p>A generic implementation of a
37 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
38 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
39 * based on an enumerated list of <value, probability> pairs. Input probabilities must all be non-negative,
40 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
41 * probabilities to make them sum to one.</p>
42 *
43 * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can
44 * contain null values. The pmf created by the constructor will combine probabilities of equal values and
45 * will treat null values as equal. For example, if the list of pairs <"dog", 0.2>, <null, 0.1>,
46 * <"pig", 0.2>, <"dog", 0.1>, <null, 0.4> is provided to the constructor, the resulting
47 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to pig.</p>
48 *
49 * @param <T> type of the elements in the sample space.
50 * @since 3.2
51 */
52 public class EnumeratedDistribution<T> {
53 /**
54 * List of random variable values.
55 */
56 private final List<T> singletons;
57 /**
58 * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
59 * probability[i] is the probability that a random variable following this distribution takes
60 * the value singletons[i].
61 */
62 private final double[] probabilities;
63 /**
64 * Cumulative probabilities, cached to speed up sampling.
65 */
66 private final double[] cumulativeProbabilities;
67
68 /**
69 * Create an enumerated distribution using the given random number generator
70 * and probability mass function enumeration.
71 *
72 * @param pmf probability mass function enumerated as a list of
73 * {@code <T, probability>} pairs.
74 * @throws NotPositiveException if any of the probabilities are negative.
75 * @throws NotFiniteNumberException if any of the probabilities are infinite.
76 * @throws NotANumberException if any of the probabilities are NaN.
77 * @throws MathArithmeticException all of the probabilities are 0.
78 */
79 public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
80 throws NotPositiveException,
81 MathArithmeticException,
82 NotFiniteNumberException,
83 NotANumberException {
84 singletons = new ArrayList<>(pmf.size());
85 final double[] probs = new double[pmf.size()];
86 int count = 0;
87 for (Pair<T, Double> sample : pmf) {
88 singletons.add(sample.getKey());
89 final double p = sample.getValue();
90 if (p < 0) {
91 throw new NotPositiveException(sample.getValue());
92 }
93 if (Double.isInfinite(p)) {
94 throw new NotFiniteNumberException(p);
95 }
96 if (Double.isNaN(p)) {
97 throw new NotANumberException();
98 }
99 probs[count++] = p;
100 }
101
102 probabilities = MathArrays.normalizeArray(probs, 1.0);
103
104 cumulativeProbabilities = new double[probabilities.length];
105 double sum = 0;
106 for (int i = 0; i < probabilities.length; i++) {
107 sum += probabilities[i];
108 cumulativeProbabilities[i] = sum;
109 }
110 }
111
112 /**
113 * <p>For a random variable {@code X} whose values are distributed according to
114 * this distribution, this method returns {@code P(X = x)}. In other words,
115 * this method represents the probability mass function (PMF) for the
116 * distribution.</p>
117 *
118 * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
119 * or both are null, then {@code probability(x1) = probability(x2)}.</p>
120 *
121 * @param x the point at which the PMF is evaluated
122 * @return the value of the probability mass function at {@code x}
123 */
124 double probability(final T x) {
125 double probability = 0;
126
127 for (int i = 0; i < probabilities.length; i++) {
128 if ((x == null && singletons.get(i) == null) ||
129 (x != null && x.equals(singletons.get(i)))) {
130 probability += probabilities[i];
131 }
132 }
133
134 return probability;
135 }
136
137 /**
138 * <p>Return the probability mass function as a list of <value, probability> pairs.</p>
139 *
140 * <p>Note that if duplicate and / or null values were provided to the constructor
141 * when creating this EnumeratedDistribution, the returned list will contain these
142 * values. If duplicates values exist, what is returned will not represent
143 * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
144 *
145 * @return the probability mass function.
146 */
147 public List<Pair<T, Double>> getPmf() {
148 final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
149
150 for (int i = 0; i < probabilities.length; i++) {
151 samples.add(new Pair<>(singletons.get(i), probabilities[i]));
152 }
153
154 return samples;
155 }
156
157 /**
158 * Creates a {@link Sampler}.
159 *
160 * @param rng Random number generator.
161 * @return a new sampler instance.
162 */
163 public Sampler createSampler(final UniformRandomProvider rng) {
164 return new Sampler(rng);
165 }
166
167 /**
168 * Sampler functionality.
169 *
170 * <ul>
171 * <li>
172 * The cumulative probability distribution is created (and sampled from)
173 * using the input order of the {@link EnumeratedDistribution#EnumeratedDistribution(List)
174 * constructor arguments}: A different input order will create a different
175 * sequence of samples.
176 * The samples will only be reproducible with the same RNG starting from
177 * the same RNG state and the same input order to constructor.
178 * </li>
179 * <li>
180 * The minimum supported probability is 2<sup>-53</sup>.
181 * </li>
182 * </ul>
183 */
184 public class Sampler {
185 /** Underlying sampler. */
186 private final DiscreteProbabilityCollectionSampler<T> sampler;
187
188 /**
189 * @param rng Random number generator.
190 */
191 Sampler(UniformRandomProvider rng) {
192 sampler = new DiscreteProbabilityCollectionSampler<>(rng, singletons, probabilities);
193 }
194
195 /**
196 * Generates a random value sampled from this distribution.
197 *
198 * @return a random value.
199 */
200 public T sample() {
201 return sampler.sample();
202 }
203
204 /**
205 * Generates a random sample from the distribution.
206 *
207 * @param sampleSize the number of random values to generate.
208 * @return an array representing the random sample.
209 * @throws NotStrictlyPositiveException if {@code sampleSize} is not
210 * positive.
211 */
212 public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
213 if (sampleSize <= 0) {
214 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
215 sampleSize);
216 }
217
218 final Object[] out = new Object[sampleSize];
219
220 for (int i = 0; i < sampleSize; i++) {
221 out[i] = sample();
222 }
223
224 return out;
225 }
226
227 /**
228 * Generates a random sample from the distribution.
229 * <p>
230 * If the requested samples fit in the specified array, it is returned
231 * therein. Otherwise, a new array is allocated with the runtime type of
232 * the specified array and the size of this collection.
233 *
234 * @param sampleSize the number of random values to generate.
235 * @param array the array to populate.
236 * @return an array representing the random sample.
237 * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
238 * @throws NullArgumentException if {@code array} is null
239 */
240 public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
241 if (sampleSize <= 0) {
242 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
243 }
244
245 if (array == null) {
246 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
247 }
248
249 T[] out;
250 if (array.length < sampleSize) {
251 @SuppressWarnings("unchecked") // safe as both are of type T
252 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
253 out = unchecked;
254 } else {
255 out = array;
256 }
257
258 for (int i = 0; i < sampleSize; i++) {
259 out[i] = sample();
260 }
261
262 return out;
263 }
264 }
265 }