EnumeratedDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.distribution;

  18. import java.lang.reflect.Array;
  19. import java.util.ArrayList;
  20. import java.util.List;

  21. import org.apache.commons.math4.legacy.exception.MathArithmeticException;
  22. import org.apache.commons.math4.legacy.exception.NotANumberException;
  23. import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
  24. import org.apache.commons.math4.legacy.exception.NotPositiveException;
  25. import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
  26. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  27. import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
  28. import org.apache.commons.rng.UniformRandomProvider;
  29. import org.apache.commons.rng.sampling.DiscreteProbabilityCollectionSampler;
  30. import org.apache.commons.math4.legacy.core.MathArrays;
  31. import org.apache.commons.math4.legacy.core.Pair;

  32. /**
  33.  * <p>A generic implementation of a
  34.  * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
  35.  * discrete probability distribution (Wikipedia)</a> over a finite sample space,
  36.  * based on an enumerated list of &lt;value, probability&gt; pairs.  Input probabilities must all be non-negative,
  37.  * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
  38.  * probabilities to make them sum to one.</p>
  39.  *
  40.  * <p>The list of &lt;value, probability&gt; pairs does not, strictly speaking, have to be a function and it can
  41.  * contain null values.  The pmf created by the constructor will combine probabilities of equal values and
  42.  * will treat null values as equal.  For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
  43.  * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
  44.  * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to pig.</p>
  45.  *
  46.  * @param <T> type of the elements in the sample space.
  47.  * @since 3.2
  48.  */
  49. public class EnumeratedDistribution<T> {
  50.     /**
  51.      * List of random variable values.
  52.      */
  53.     private final List<T> singletons;
  54.     /**
  55.      * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
  56.      * probability[i] is the probability that a random variable following this distribution takes
  57.      * the value singletons[i].
  58.      */
  59.     private final double[] probabilities;
  60.     /**
  61.      * Cumulative probabilities, cached to speed up sampling.
  62.      */
  63.     private final double[] cumulativeProbabilities;

  64.     /**
  65.      * Create an enumerated distribution using the given random number generator
  66.      * and probability mass function enumeration.
  67.      *
  68.      * @param pmf probability mass function enumerated as a list of
  69.      * {@code <T, probability>} pairs.
  70.      * @throws NotPositiveException if any of the probabilities are negative.
  71.      * @throws NotFiniteNumberException if any of the probabilities are infinite.
  72.      * @throws NotANumberException if any of the probabilities are NaN.
  73.      * @throws MathArithmeticException all of the probabilities are 0.
  74.      */
  75.     public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
  76.         throws NotPositiveException,
  77.                MathArithmeticException,
  78.                NotFiniteNumberException,
  79.                NotANumberException {
  80.         singletons = new ArrayList<>(pmf.size());
  81.         final double[] probs = new double[pmf.size()];
  82.         int count = 0;
  83.         for (Pair<T, Double> sample : pmf) {
  84.             singletons.add(sample.getKey());
  85.             final double p = sample.getValue();
  86.             if (p < 0) {
  87.                 throw new NotPositiveException(sample.getValue());
  88.             }
  89.             if (Double.isInfinite(p)) {
  90.                 throw new NotFiniteNumberException(p);
  91.             }
  92.             if (Double.isNaN(p)) {
  93.                 throw new NotANumberException();
  94.             }
  95.             probs[count++] = p;
  96.         }

  97.         probabilities = MathArrays.normalizeArray(probs, 1.0);

  98.         cumulativeProbabilities = new double[probabilities.length];
  99.         double sum = 0;
  100.         for (int i = 0; i < probabilities.length; i++) {
  101.             sum += probabilities[i];
  102.             cumulativeProbabilities[i] = sum;
  103.         }
  104.     }

  105.     /**
  106.      * <p>For a random variable {@code X} whose values are distributed according to
  107.      * this distribution, this method returns {@code P(X = x)}. In other words,
  108.      * this method represents the probability mass function (PMF) for the
  109.      * distribution.</p>
  110.      *
  111.      * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
  112.      * or both are null, then {@code probability(x1) = probability(x2)}.</p>
  113.      *
  114.      * @param x the point at which the PMF is evaluated
  115.      * @return the value of the probability mass function at {@code x}
  116.      */
  117.     double probability(final T x) {
  118.         double probability = 0;

  119.         for (int i = 0; i < probabilities.length; i++) {
  120.             if ((x == null && singletons.get(i) == null) ||
  121.                 (x != null && x.equals(singletons.get(i)))) {
  122.                 probability += probabilities[i];
  123.             }
  124.         }

  125.         return probability;
  126.     }

  127.     /**
  128.      * <p>Return the probability mass function as a list of &lt;value, probability&gt; pairs.</p>
  129.      *
  130.      * <p>Note that if duplicate and / or null values were provided to the constructor
  131.      * when creating this EnumeratedDistribution, the returned list will contain these
  132.      * values.  If duplicates values exist, what is returned will not represent
  133.      * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
  134.      *
  135.      * @return the probability mass function.
  136.      */
  137.     public List<Pair<T, Double>> getPmf() {
  138.         final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);

  139.         for (int i = 0; i < probabilities.length; i++) {
  140.             samples.add(new Pair<>(singletons.get(i), probabilities[i]));
  141.         }

  142.         return samples;
  143.     }

  144.     /**
  145.      * Creates a {@link Sampler}.
  146.      *
  147.      * @param rng Random number generator.
  148.      * @return a new sampler instance.
  149.      */
  150.     public Sampler createSampler(final UniformRandomProvider rng) {
  151.         return new Sampler(rng);
  152.     }

  153.     /**
  154.      * Sampler functionality.
  155.      *
  156.      * <ul>
  157.      *  <li>
  158.      *   The cumulative probability distribution is created (and sampled from)
  159.      *   using the input order of the {@link EnumeratedDistribution#EnumeratedDistribution(List)
  160.      *   constructor arguments}: A different input order will create a different
  161.      *   sequence of samples.
  162.      *   The samples will only be reproducible with the same RNG starting from
  163.      *   the same RNG state and the same input order to constructor.
  164.      *  </li>
  165.      *  <li>
  166.      *   The minimum supported probability is 2<sup>-53</sup>.
  167.      *  </li>
  168.      * </ul>
  169.      */
  170.     public class Sampler {
  171.         /** Underlying sampler. */
  172.         private final DiscreteProbabilityCollectionSampler<T> sampler;

  173.         /**
  174.          * @param rng Random number generator.
  175.          */
  176.         Sampler(UniformRandomProvider rng) {
  177.             sampler = new DiscreteProbabilityCollectionSampler<>(rng, singletons, probabilities);
  178.         }

  179.         /**
  180.          * Generates a random value sampled from this distribution.
  181.          *
  182.          * @return a random value.
  183.          */
  184.         public T sample() {
  185.             return sampler.sample();
  186.         }

  187.         /**
  188.          * Generates a random sample from the distribution.
  189.          *
  190.          * @param sampleSize the number of random values to generate.
  191.          * @return an array representing the random sample.
  192.          * @throws NotStrictlyPositiveException if {@code sampleSize} is not
  193.          * positive.
  194.          */
  195.         public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
  196.             if (sampleSize <= 0) {
  197.                 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
  198.                                                        sampleSize);
  199.             }

  200.             final Object[] out = new Object[sampleSize];

  201.             for (int i = 0; i < sampleSize; i++) {
  202.                 out[i] = sample();
  203.             }

  204.             return out;
  205.         }

  206.         /**
  207.          * Generates a random sample from the distribution.
  208.          * <p>
  209.          * If the requested samples fit in the specified array, it is returned
  210.          * therein. Otherwise, a new array is allocated with the runtime type of
  211.          * the specified array and the size of this collection.
  212.          *
  213.          * @param sampleSize the number of random values to generate.
  214.          * @param array the array to populate.
  215.          * @return an array representing the random sample.
  216.          * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
  217.          * @throws NullArgumentException if {@code array} is null
  218.          */
  219.         public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
  220.             if (sampleSize <= 0) {
  221.                 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
  222.             }

  223.             if (array == null) {
  224.                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
  225.             }

  226.             T[] out;
  227.             if (array.length < sampleSize) {
  228.                 @SuppressWarnings("unchecked") // safe as both are of type T
  229.                 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
  230.                 out = unchecked;
  231.             } else {
  232.                 out = array;
  233.             }

  234.             for (int i = 0; i < sampleSize; i++) {
  235.                 out[i] = sample();
  236.             }

  237.             return out;
  238.         }
  239.     }
  240. }