View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import java.lang.reflect.Array;
20  import java.util.ArrayList;
21  import java.util.List;
22  
23  import org.apache.commons.math4.legacy.exception.MathArithmeticException;
24  import org.apache.commons.math4.legacy.exception.NotANumberException;
25  import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
26  import org.apache.commons.math4.legacy.exception.NotPositiveException;
27  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
28  import org.apache.commons.math4.legacy.exception.NullArgumentException;
29  import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
30  import org.apache.commons.rng.UniformRandomProvider;
31  import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
32  import org.apache.commons.math4.legacy.core.MathArrays;
33  import org.apache.commons.math4.legacy.core.Pair;
34  
35  /**
36   * <p>A generic implementation of a
37   * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
38   * discrete probability distribution (Wikipedia)</a> over a finite sample space,
39   * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
40   * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
41   * probabilities to make them sum to one.</p>
42   *
43   * <p>The list of &lt;value, probability&gt; pairs does not, strictly speaking, have to be a function and it can
44   * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
45   * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
46   * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
47   * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to pig.</p>
48   *
49   * @param <T> type of the elements in the sample space.
50   * @since 3.2
51   */
52  public class EnumeratedDistribution<T> {
53      /**
54       * List of random variable values.
55       */
56      private final List<T> singletons;
57      /**
58       * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
59       * probability[i] is the probability that a random variable following this distribution takes
60       * the value singletons[i].
61       */
62      private final double[] probabilities;
63      /**
64       * Cumulative probabilities, cached to speed up sampling.
65       */
66      private final double[] cumulativeProbabilities;
67  
68      /**
69       * Create an enumerated distribution using the given random number generator
70       * and probability mass function enumeration.
71       *
72       * @param pmf probability mass function enumerated as a list of
73       * {@code <T, probability>} pairs.
74       * @throws NotPositiveException if any of the probabilities are negative.
75       * @throws NotFiniteNumberException if any of the probabilities are infinite.
76       * @throws NotANumberException if any of the probabilities are NaN.
77       * @throws MathArithmeticException all of the probabilities are 0.
78       */
79      public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
80          throws NotPositiveException,
81                 MathArithmeticException,
82                 NotFiniteNumberException,
83                 NotANumberException {
84          singletons = new ArrayList<>(pmf.size());
85          final double[] probs = new double[pmf.size()];
86          int count = 0;
87          for (Pair<T, Double> sample : pmf) {
88              singletons.add(sample.getKey());
89              final double p = sample.getValue();
90              if (p < 0) {
91                  throw new NotPositiveException(sample.getValue());
92              }
93              if (Double.isInfinite(p)) {
94                  throw new NotFiniteNumberException(p);
95              }
96              if (Double.isNaN(p)) {
97                  throw new NotANumberException();
98              }
99              probs[count++] = p;
100         }
101 
102         probabilities = MathArrays.normalizeArray(probs, 1.0);
103 
104         cumulativeProbabilities = new double[probabilities.length];
105         double sum = 0;
106         for (int i = 0; i < probabilities.length; i++) {
107             sum += probabilities[i];
108             cumulativeProbabilities[i] = sum;
109         }
110     }
111 
112     /**
113      * <p>For a random variable {@code X} whose values are distributed according to
114      * this distribution, this method returns {@code P(X = x)}. In other words,
115      * this method represents the probability mass function (PMF) for the
116      * distribution.</p>
117      *
118      * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
119      * or both are null, then {@code probability(x1) = probability(x2)}.</p>
120      *
121      * @param x the point at which the PMF is evaluated
122      * @return the value of the probability mass function at {@code x}
123      */
124     double probability(final T x) {
125         double probability = 0;
126 
127         for (int i = 0; i < probabilities.length; i++) {
128             if ((x == null && singletons.get(i) == null) ||
129                 (x != null && x.equals(singletons.get(i)))) {
130                 probability += probabilities[i];
131             }
132         }
133 
134         return probability;
135     }
136 
137     /**
138      * <p>Return the probability mass function as a list of &lt;value, probability&gt; pairs.</p>
139      *
140      * <p>Note that if duplicate and / or null values were provided to the constructor
141      * when creating this EnumeratedDistribution, the returned list will contain these
142      * values.  If duplicates values exist, what is returned will not represent
143      * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
144      *
145      * @return the probability mass function.
146      */
147     public List<Pair<T, Double>> getPmf() {
148         final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
149 
150         for (int i = 0; i < probabilities.length; i++) {
151             samples.add(new Pair<>(singletons.get(i), probabilities[i]));
152         }
153 
154         return samples;
155     }
156 
157     /**
158      * Creates a {@link Sampler}.
159      *
160      * @param rng Random number generator.
161      * @return a new sampler instance.
162      */
163     public Sampler createSampler(final UniformRandomProvider rng) {
164         return new Sampler(rng);
165     }
166 
167     /**
168      * Sampler functionality.
169      *
170      * <ul>
171      *  <li>
172      *   The cumulative probability distribution is created (and sampled from)
173      *   using the input order of the {@link EnumeratedDistribution#EnumeratedDistribution(List)
174      *   constructor arguments}: A different input order will create a different
175      *   sequence of samples.
176      *   The samples will only be reproducible with the same RNG starting from
177      *   the same RNG state and the same input order to constructor.
178      *  </li>
179      *  <li>
180      *   The minimum supported probability is 2<sup>-53</sup>.
181      *  </li>
182      * </ul>
183      */
184     public class Sampler {
185         /** Underlying sampler. */
186         private final DiscreteProbabilityCollectionSampler<T> sampler;
187 
188         /**
189          * @param rng Random number generator.
190          */
191         Sampler(UniformRandomProvider rng) {
192             sampler = new DiscreteProbabilityCollectionSampler<>(rng, singletons, probabilities);
193         }
194 
195         /**
196          * Generates a random value sampled from this distribution.
197          *
198          * @return a random value.
199          */
200         public T sample() {
201             return sampler.sample();
202         }
203 
204         /**
205          * Generates a random sample from the distribution.
206          *
207          * @param sampleSize the number of random values to generate.
208          * @return an array representing the random sample.
209          * @throws NotStrictlyPositiveException if {@code sampleSize} is not
210          * positive.
211          */
212         public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
213             if (sampleSize <= 0) {
214                 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
215                                                        sampleSize);
216             }
217 
218             final Object[] out = new Object[sampleSize];
219 
220             for (int i = 0; i < sampleSize; i++) {
221                 out[i] = sample();
222             }
223 
224             return out;
225         }
226 
227         /**
228          * Generates a random sample from the distribution.
229          * <p>
230          * If the requested samples fit in the specified array, it is returned
231          * therein. Otherwise, a new array is allocated with the runtime type of
232          * the specified array and the size of this collection.
233          *
234          * @param sampleSize the number of random values to generate.
235          * @param array the array to populate.
236          * @return an array representing the random sample.
237          * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
238          * @throws NullArgumentException if {@code array} is null
239          */
240         public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
241             if (sampleSize <= 0) {
242                 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
243             }
244 
245             if (array == null) {
246                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
247             }
248 
249             T[] out;
250             if (array.length < sampleSize) {
251                 @SuppressWarnings("unchecked") // safe as both are of type T
252                 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
253                 out = unchecked;
254             } else {
255                 out = array;
256             }
257 
258             for (int i = 0; i < sampleSize; i++) {
259                 out[i] = sample();
260             }
261 
262             return out;
263         }
264     }
265 }