001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.distribution;
018
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.apache.commons.math4.legacy.exception.MathArithmeticException;
import org.apache.commons.math4.legacy.exception.NotANumberException;
import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
import org.apache.commons.math4.legacy.exception.NotPositiveException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.core.Pair;
034
035/**
036 * <p>A generic implementation of a
037 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
038 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
039 * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
040 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
041 * probabilities to make them sum to one.</p>
042 *
043 * <p>The list of &lt;value, probability&gt; pairs does not, strictly speaking, have to be a function and it can
044 * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
045 * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
046 * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
047 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to pig.</p>
048 *
049 * @param <T> type of the elements in the sample space.
050 * @since 3.2
051 */
052public class EnumeratedDistribution<T> {
053    /**
054     * List of random variable values.
055     */
056    private final List<T> singletons;
057    /**
058     * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
059     * probability[i] is the probability that a random variable following this distribution takes
060     * the value singletons[i].
061     */
062    private final double[] probabilities;
063    /**
064     * Cumulative probabilities, cached to speed up sampling.
065     */
066    private final double[] cumulativeProbabilities;
067
068    /**
069     * Create an enumerated distribution using the given random number generator
070     * and probability mass function enumeration.
071     *
072     * @param pmf probability mass function enumerated as a list of
073     * {@code <T, probability>} pairs.
074     * @throws NotPositiveException if any of the probabilities are negative.
075     * @throws NotFiniteNumberException if any of the probabilities are infinite.
076     * @throws NotANumberException if any of the probabilities are NaN.
077     * @throws MathArithmeticException all of the probabilities are 0.
078     */
079    public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
080        throws NotPositiveException,
081               MathArithmeticException,
082               NotFiniteNumberException,
083               NotANumberException {
084        singletons = new ArrayList<>(pmf.size());
085        final double[] probs = new double[pmf.size()];
086        int count = 0;
087        for (Pair<T, Double> sample : pmf) {
088            singletons.add(sample.getKey());
089            final double p = sample.getValue();
090            if (p < 0) {
091                throw new NotPositiveException(sample.getValue());
092            }
093            if (Double.isInfinite(p)) {
094                throw new NotFiniteNumberException(p);
095            }
096            if (Double.isNaN(p)) {
097                throw new NotANumberException();
098            }
099            probs[count++] = p;
100        }
101
102        probabilities = MathArrays.normalizeArray(probs, 1.0);
103
104        cumulativeProbabilities = new double[probabilities.length];
105        double sum = 0;
106        for (int i = 0; i < probabilities.length; i++) {
107            sum += probabilities[i];
108            cumulativeProbabilities[i] = sum;
109        }
110    }
111
112    /**
113     * <p>For a random variable {@code X} whose values are distributed according to
114     * this distribution, this method returns {@code P(X = x)}. In other words,
115     * this method represents the probability mass function (PMF) for the
116     * distribution.</p>
117     *
118     * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
119     * or both are null, then {@code probability(x1) = probability(x2)}.</p>
120     *
121     * @param x the point at which the PMF is evaluated
122     * @return the value of the probability mass function at {@code x}
123     */
124    double probability(final T x) {
125        double probability = 0;
126
127        for (int i = 0; i < probabilities.length; i++) {
128            if ((x == null && singletons.get(i) == null) ||
129                (x != null && x.equals(singletons.get(i)))) {
130                probability += probabilities[i];
131            }
132        }
133
134        return probability;
135    }
136
137    /**
138     * <p>Return the probability mass function as a list of &lt;value, probability&gt; pairs.</p>
139     *
140     * <p>Note that if duplicate and / or null values were provided to the constructor
141     * when creating this EnumeratedDistribution, the returned list will contain these
142     * values.  If duplicates values exist, what is returned will not represent
143     * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
144     *
145     * @return the probability mass function.
146     */
147    public List<Pair<T, Double>> getPmf() {
148        final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
149
150        for (int i = 0; i < probabilities.length; i++) {
151            samples.add(new Pair<>(singletons.get(i), probabilities[i]));
152        }
153
154        return samples;
155    }
156
157    /**
158     * Creates a {@link Sampler}.
159     *
160     * @param rng Random number generator.
161     * @return a new sampler instance.
162     */
163    public Sampler createSampler(final UniformRandomProvider rng) {
164        return new Sampler(rng);
165    }
166
167    /**
168     * Sampler functionality.
169     *
170     * <ul>
171     *  <li>
172     *   The cumulative probability distribution is created (and sampled from)
173     *   using the input order of the {@link EnumeratedDistribution#EnumeratedDistribution(List)
174     *   constructor arguments}: A different input order will create a different
175     *   sequence of samples.
176     *   The samples will only be reproducible with the same RNG starting from
177     *   the same RNG state and the same input order to constructor.
178     *  </li>
179     *  <li>
180     *   The minimum supported probability is 2<sup>-53</sup>.
181     *  </li>
182     * </ul>
183     */
184    public class Sampler {
185        /** Underlying sampler. */
186        private final DiscreteProbabilityCollectionSampler<T> sampler;
187
188        /**
189         * @param rng Random number generator.
190         */
191        Sampler(UniformRandomProvider rng) {
192            sampler = new DiscreteProbabilityCollectionSampler<>(rng, singletons, probabilities);
193        }
194
195        /**
196         * Generates a random value sampled from this distribution.
197         *
198         * @return a random value.
199         */
200        public T sample() {
201            return sampler.sample();
202        }
203
204        /**
205         * Generates a random sample from the distribution.
206         *
207         * @param sampleSize the number of random values to generate.
208         * @return an array representing the random sample.
209         * @throws NotStrictlyPositiveException if {@code sampleSize} is not
210         * positive.
211         */
212        public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
213            if (sampleSize <= 0) {
214                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
215                                                       sampleSize);
216            }
217
218            final Object[] out = new Object[sampleSize];
219
220            for (int i = 0; i < sampleSize; i++) {
221                out[i] = sample();
222            }
223
224            return out;
225        }
226
227        /**
228         * Generates a random sample from the distribution.
229         * <p>
230         * If the requested samples fit in the specified array, it is returned
231         * therein. Otherwise, a new array is allocated with the runtime type of
232         * the specified array and the size of this collection.
233         *
234         * @param sampleSize the number of random values to generate.
235         * @param array the array to populate.
236         * @return an array representing the random sample.
237         * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
238         * @throws NullArgumentException if {@code array} is null
239         */
240        public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
241            if (sampleSize <= 0) {
242                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
243            }
244
245            if (array == null) {
246                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
247            }
248
249            T[] out;
250            if (array.length < sampleSize) {
251                @SuppressWarnings("unchecked") // safe as both are of type T
252                final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
253                out = unchecked;
254            } else {
255                out = array;
256            }
257
258            for (int i = 0; i < sampleSize; i++) {
259                out[i] = sample();
260            }
261
262            return out;
263        }
264    }
265}