001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.distribution;
018
019import java.io.Serializable;
020import java.lang.reflect.Array;
021import java.util.ArrayList;
022import java.util.Arrays;
023import java.util.List;
024
025import org.apache.commons.math3.exception.MathArithmeticException;
026import org.apache.commons.math3.exception.NotANumberException;
027import org.apache.commons.math3.exception.NotFiniteNumberException;
028import org.apache.commons.math3.exception.NotPositiveException;
029import org.apache.commons.math3.exception.NotStrictlyPositiveException;
030import org.apache.commons.math3.exception.NullArgumentException;
031import org.apache.commons.math3.exception.util.LocalizedFormats;
032import org.apache.commons.math3.random.RandomGenerator;
033import org.apache.commons.math3.random.Well19937c;
034import org.apache.commons.math3.util.MathArrays;
035import org.apache.commons.math3.util.Pair;
036
037/**
038 * <p>A generic implementation of a
039 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
040 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
041 * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
042 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
043 * probabilities to make them sum to one.</p>
044 *
045 * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can
046 * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
047 * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
048 * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
049 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to null.</p>
050 *
051 * @param <T> type of the elements in the sample space.
052 * @since 3.2
053 */
054public class EnumeratedDistribution<T> implements Serializable {
055
056    /** Serializable UID. */
057    private static final long serialVersionUID = 20123308L;
058
059    /**
060     * RNG instance used to generate samples from the distribution.
061     */
062    protected final RandomGenerator random;
063
064    /**
065     * List of random variable values.
066     */
067    private final List<T> singletons;
068
069    /**
070     * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
071     * probability[i] is the probability that a random variable following this distribution takes
072     * the value singletons[i].
073     */
074    private final double[] probabilities;
075
076    /**
077     * Cumulative probabilities, cached to speed up sampling.
078     */
079    private final double[] cumulativeProbabilities;
080
081    /**
082     * Create an enumerated distribution using the given probability mass function
083     * enumeration.
084     * <p>
085     * <b>Note:</b> this constructor will implicitly create an instance of
086     * {@link Well19937c} as random generator to be used for sampling only (see
087     * {@link #sample()} and {@link #sample(int)}). In case no sampling is
088     * needed for the created distribution, it is advised to pass {@code null}
089     * as random generator via the appropriate constructors to avoid the
090     * additional initialisation overhead.
091     *
092     * @param pmf probability mass function enumerated as a list of <T, probability>
093     * pairs.
094     * @throws NotPositiveException if any of the probabilities are negative.
095     * @throws NotFiniteNumberException if any of the probabilities are infinite.
096     * @throws NotANumberException if any of the probabilities are NaN.
097     * @throws MathArithmeticException all of the probabilities are 0.
098     */
099    public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
100        throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
101        this(new Well19937c(), pmf);
102    }
103
104    /**
105     * Create an enumerated distribution using the given random number generator
106     * and probability mass function enumeration.
107     *
108     * @param rng random number generator.
109     * @param pmf probability mass function enumerated as a list of <T, probability>
110     * pairs.
111     * @throws NotPositiveException if any of the probabilities are negative.
112     * @throws NotFiniteNumberException if any of the probabilities are infinite.
113     * @throws NotANumberException if any of the probabilities are NaN.
114     * @throws MathArithmeticException all of the probabilities are 0.
115     */
116    public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf)
117        throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
118        random = rng;
119
120        singletons = new ArrayList<T>(pmf.size());
121        final double[] probs = new double[pmf.size()];
122
123        for (int i = 0; i < pmf.size(); i++) {
124            final Pair<T, Double> sample = pmf.get(i);
125            singletons.add(sample.getKey());
126            final double p = sample.getValue();
127            if (p < 0) {
128                throw new NotPositiveException(sample.getValue());
129            }
130            if (Double.isInfinite(p)) {
131                throw new NotFiniteNumberException(p);
132            }
133            if (Double.isNaN(p)) {
134                throw new NotANumberException();
135            }
136            probs[i] = p;
137        }
138
139        probabilities = MathArrays.normalizeArray(probs, 1.0);
140
141        cumulativeProbabilities = new double[probabilities.length];
142        double sum = 0;
143        for (int i = 0; i < probabilities.length; i++) {
144            sum += probabilities[i];
145            cumulativeProbabilities[i] = sum;
146        }
147    }
148
149    /**
150     * Reseed the random generator used to generate samples.
151     *
152     * @param seed the new seed
153     */
154    public void reseedRandomGenerator(long seed) {
155        random.setSeed(seed);
156    }
157
158    /**
159     * <p>For a random variable {@code X} whose values are distributed according to
160     * this distribution, this method returns {@code P(X = x)}. In other words,
161     * this method represents the probability mass function (PMF) for the
162     * distribution.</p>
163     *
164     * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
165     * or both are null, then {@code probability(x1) = probability(x2)}.</p>
166     *
167     * @param x the point at which the PMF is evaluated
168     * @return the value of the probability mass function at {@code x}
169     */
170    double probability(final T x) {
171        double probability = 0;
172
173        for (int i = 0; i < probabilities.length; i++) {
174            if ((x == null && singletons.get(i) == null) ||
175                (x != null && x.equals(singletons.get(i)))) {
176                probability += probabilities[i];
177            }
178        }
179
180        return probability;
181    }
182
183    /**
184     * <p>Return the probability mass function as a list of <value, probability> pairs.</p>
185     *
186     * <p>Note that if duplicate and / or null values were provided to the constructor
187     * when creating this EnumeratedDistribution, the returned list will contain these
188     * values.  If duplicates values exist, what is returned will not represent
189     * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
190     *
191     * @return the probability mass function.
192     */
193    public List<Pair<T, Double>> getPmf() {
194        final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length);
195
196        for (int i = 0; i < probabilities.length; i++) {
197            samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i]));
198        }
199
200        return samples;
201    }
202
203    /**
204     * Generate a random value sampled from this distribution.
205     *
206     * @return a random value.
207     */
208    public T sample() {
209        final double randomValue = random.nextDouble();
210
211        int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
212        if (index < 0) {
213            index = -index-1;
214        }
215
216        if (index >= 0 &&
217            index < probabilities.length &&
218            randomValue < cumulativeProbabilities[index]) {
219            return singletons.get(index);
220        }
221
222        /* This should never happen, but it ensures we will return a correct
223         * object in case there is some floating point inequality problem
224         * wrt the cumulative probabilities. */
225        return singletons.get(singletons.size() - 1);
226    }
227
228    /**
229     * Generate a random sample from the distribution.
230     *
231     * @param sampleSize the number of random values to generate.
232     * @return an array representing the random sample.
233     * @throws NotStrictlyPositiveException if {@code sampleSize} is not
234     * positive.
235     */
236    public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
237        if (sampleSize <= 0) {
238            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
239                    sampleSize);
240        }
241
242        final Object[] out = new Object[sampleSize];
243
244        for (int i = 0; i < sampleSize; i++) {
245            out[i] = sample();
246        }
247
248        return out;
249
250    }
251
252    /**
253     * Generate a random sample from the distribution.
254     * <p>
255     * If the requested samples fit in the specified array, it is returned
256     * therein. Otherwise, a new array is allocated with the runtime type of
257     * the specified array and the size of this collection.
258     *
259     * @param sampleSize the number of random values to generate.
260     * @param array the array to populate.
261     * @return an array representing the random sample.
262     * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
263     * @throws NullArgumentException if {@code array} is null
264     */
265    public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
266        if (sampleSize <= 0) {
267            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
268        }
269
270        if (array == null) {
271            throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
272        }
273
274        T[] out;
275        if (array.length < sampleSize) {
276            @SuppressWarnings("unchecked") // safe as both are of type T
277            final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
278            out = unchecked;
279        } else {
280            out = array;
281        }
282
283        for (int i = 0; i < sampleSize; i++) {
284            out[i] = sample();
285        }
286
287        return out;
288
289    }
290
291}