001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.distribution;
018
019import java.io.Serializable;
020import java.lang.reflect.Array;
021import java.util.ArrayList;
022import java.util.List;
023
024import org.apache.commons.math4.exception.MathArithmeticException;
025import org.apache.commons.math4.exception.NotANumberException;
026import org.apache.commons.math4.exception.NotFiniteNumberException;
027import org.apache.commons.math4.exception.NotPositiveException;
028import org.apache.commons.math4.exception.NotStrictlyPositiveException;
029import org.apache.commons.math4.exception.NullArgumentException;
030import org.apache.commons.math4.exception.util.LocalizedFormats;
031import org.apache.commons.rng.UniformRandomProvider;
032import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
033import org.apache.commons.math4.util.MathArrays;
034import org.apache.commons.math4.util.Pair;
035
036/**
037 * <p>A generic implementation of a
038 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
039 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
040 * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
041 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
042 * probabilities to make them sum to one.</p>
043 *
044 * <p>The list of &lt;value, probability&gt; pairs does not, strictly speaking, have to be a function and it can
045 * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
046 * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
047 * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
048 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to pig.</p>
049 *
050 * @param <T> type of the elements in the sample space.
051 * @since 3.2
052 */
053public class EnumeratedDistribution<T> implements Serializable {
054    /** Serializable UID. */
055    private static final long serialVersionUID = 20160319L;
056    /**
057     * List of random variable values.
058     */
059    private final List<T> singletons;
060    /**
061     * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
062     * probability[i] is the probability that a random variable following this distribution takes
063     * the value singletons[i].
064     */
065    private final double[] probabilities;
066    /**
067     * Cumulative probabilities, cached to speed up sampling.
068     */
069    private final double[] cumulativeProbabilities;
070
071    /**
072     * Create an enumerated distribution using the given random number generator
073     * and probability mass function enumeration.
074     *
075     * @param pmf probability mass function enumerated as a list of
076     * {@code <T, probability>} pairs.
077     * @throws NotPositiveException if any of the probabilities are negative.
078     * @throws NotFiniteNumberException if any of the probabilities are infinite.
079     * @throws NotANumberException if any of the probabilities are NaN.
080     * @throws MathArithmeticException all of the probabilities are 0.
081     */
082    public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
083        throws NotPositiveException,
084               MathArithmeticException,
085               NotFiniteNumberException,
086               NotANumberException {
087        singletons = new ArrayList<>(pmf.size());
088        final double[] probs = new double[pmf.size()];
089        int count = 0;
090        for (Pair<T, Double> sample : pmf) {
091            singletons.add(sample.getKey());
092            final double p = sample.getValue();
093            if (p < 0) {
094                throw new NotPositiveException(sample.getValue());
095            }
096            if (Double.isInfinite(p)) {
097                throw new NotFiniteNumberException(p);
098            }
099            if (Double.isNaN(p)) {
100                throw new NotANumberException();
101            }
102            probs[count++] = p;
103        }
104
105        probabilities = MathArrays.normalizeArray(probs, 1.0);
106
107        cumulativeProbabilities = new double[probabilities.length];
108        double sum = 0;
109        for (int i = 0; i < probabilities.length; i++) {
110            sum += probabilities[i];
111            cumulativeProbabilities[i] = sum;
112        }
113    }
114
115    /**
116     * <p>For a random variable {@code X} whose values are distributed according to
117     * this distribution, this method returns {@code P(X = x)}. In other words,
118     * this method represents the probability mass function (PMF) for the
119     * distribution.</p>
120     *
121     * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
122     * or both are null, then {@code probability(x1) = probability(x2)}.</p>
123     *
124     * @param x the point at which the PMF is evaluated
125     * @return the value of the probability mass function at {@code x}
126     */
127    double probability(final T x) {
128        double probability = 0;
129
130        for (int i = 0; i < probabilities.length; i++) {
131            if ((x == null && singletons.get(i) == null) ||
132                (x != null && x.equals(singletons.get(i)))) {
133                probability += probabilities[i];
134            }
135        }
136
137        return probability;
138    }
139
140    /**
141     * <p>Return the probability mass function as a list of &lt;value, probability&gt; pairs.</p>
142     *
143     * <p>Note that if duplicate and / or null values were provided to the constructor
144     * when creating this EnumeratedDistribution, the returned list will contain these
145     * values.  If duplicates values exist, what is returned will not represent
146     * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
147     *
148     * @return the probability mass function.
149     */
150    public List<Pair<T, Double>> getPmf() {
151        final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
152
153        for (int i = 0; i < probabilities.length; i++) {
154            samples.add(new Pair<>(singletons.get(i), probabilities[i]));
155        }
156
157        return samples;
158    }
159
160    /**
161     * Creates a {@link Sampler}.
162     *
163     * @param rng Random number generator.
164     * @return a new sampler instance.
165     */
166    public Sampler createSampler(final UniformRandomProvider rng) {
167        return new Sampler(rng);
168    }
169
170    /**
171     * Sampler functionality.
172     */
173    public class Sampler {
174        /** Underlying sampler. */
175        private final DiscreteProbabilityCollectionSampler<T> sampler;
176
177        /**
178         * @param rng Random number generator.
179         */
180        Sampler(UniformRandomProvider rng) {
181            sampler = new DiscreteProbabilityCollectionSampler<T>(rng, singletons, probabilities);
182        }
183
184        /**
185         * Generates a random value sampled from this distribution.
186         *
187         * @return a random value.
188         */
189        public T sample() {
190            return sampler.sample();
191        }
192
193        /**
194         * Generates a random sample from the distribution.
195         *
196         * @param sampleSize the number of random values to generate.
197         * @return an array representing the random sample.
198         * @throws NotStrictlyPositiveException if {@code sampleSize} is not
199         * positive.
200         */
201        public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
202            if (sampleSize <= 0) {
203                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
204                                                       sampleSize);
205            }
206
207            final Object[] out = new Object[sampleSize];
208
209            for (int i = 0; i < sampleSize; i++) {
210                out[i] = sample();
211            }
212
213            return out;
214        }
215
216        /**
217         * Generates a random sample from the distribution.
218         * <p>
219         * If the requested samples fit in the specified array, it is returned
220         * therein. Otherwise, a new array is allocated with the runtime type of
221         * the specified array and the size of this collection.
222         *
223         * @param sampleSize the number of random values to generate.
224         * @param array the array to populate.
225         * @return an array representing the random sample.
226         * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
227         * @throws NullArgumentException if {@code array} is null
228         */
229        public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
230            if (sampleSize <= 0) {
231                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
232            }
233
234            if (array == null) {
235                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
236            }
237
238            T[] out;
239            if (array.length < sampleSize) {
240                @SuppressWarnings("unchecked") // safe as both are of type T
241                final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
242                out = unchecked;
243            } else {
244                out = array;
245            }
246
247            for (int i = 0; i < sampleSize; i++) {
248                out[i] = sample();
249            }
250
251            return out;
252        }
253    }
254}