DiscreteProbabilityCollectionSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.rng.sampling;

  18. import java.util.List;
  19. import java.util.Map;
  20. import java.util.ArrayList;
  21. import org.apache.commons.rng.UniformRandomProvider;
  22. import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
  23. import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;

  24. /**
  25.  * Sampling from a collection of items with user-defined
  26.  * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
  27.  * probabilities</a>.
  28.  * Note that if all unique items are assigned the same probability,
  29.  * it is much more efficient to use {@link CollectionSampler}.
  30.  *
  31.  * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
  32.  *
  33.  * @param <T> Type of items in the collection.
  34.  *
  35.  * @since 1.1
  36.  */
  37. public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> {
  38.     /** The error message for an empty collection. */
  39.     private static final String EMPTY_COLLECTION = "Empty collection";
  40.     /** Collection to be sampled from. */
  41.     private final List<T> items;
  42.     /** Sampler for the probabilities. */
  43.     private final SharedStateDiscreteSampler sampler;

  44.     /**
  45.      * Creates a sampler.
  46.      *
  47.      * @param rng Generator of uniformly distributed random numbers.
  48.      * @param collection Collection to be sampled, with the probabilities
  49.      * associated to each of its items.
  50.      * A (shallow) copy of the items will be stored in the created instance.
  51.      * The probabilities must be non-negative, but zero values are allowed
  52.      * and their sum does not have to equal one (input will be normalized
  53.      * to make the probabilities sum to one).
  54.      * @throws IllegalArgumentException if {@code collection} is empty, a
  55.      * probability is negative, infinite or {@code NaN}, or the sum of all
  56.      * probabilities is not strictly positive.
  57.      */
  58.     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
  59.                                                 Map<T, Double> collection) {
  60.         this(toList(collection),
  61.              createSampler(rng, toProbabilities(collection)));
  62.     }

  63.     /**
  64.      * Creates a sampler.
  65.      *
  66.      * @param rng Generator of uniformly distributed random numbers.
  67.      * @param collection Collection to be sampled.
  68.      * A (shallow) copy of the items will be stored in the created instance.
  69.      * @param probabilities Probability associated to each item of the
  70.      * {@code collection}.
  71.      * The probabilities must be non-negative, but zero values are allowed
  72.      * and their sum does not have to equal one (input will be normalized
  73.      * to make the probabilities sum to one).
  74.      * @throws IllegalArgumentException if {@code collection} is empty or
  75.      * a probability is negative, infinite or {@code NaN}, or if the number
  76.      * of items in the {@code collection} is not equal to the number of
  77.      * provided {@code probabilities}.
  78.      */
  79.     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
  80.                                                 List<T> collection,
  81.                                                 double[] probabilities) {
  82.         this(copyList(collection),
  83.              createSampler(rng, collection, probabilities));
  84.     }

  85.     /**
  86.      * @param items Collection to be sampled.
  87.      * @param sampler Sampler for the probabilities.
  88.      */
  89.     private DiscreteProbabilityCollectionSampler(List<T> items,
  90.                                                  SharedStateDiscreteSampler sampler) {
  91.         this.items = items;
  92.         this.sampler = sampler;
  93.     }

  94.     /**
  95.      * Picks one of the items from the collection passed to the constructor.
  96.      *
  97.      * @return a random sample.
  98.      */
  99.     @Override
  100.     public T sample() {
  101.         return items.get(sampler.sample());
  102.     }

  103.     /**
  104.      * {@inheritDoc}
  105.      *
  106.      * @since 1.3
  107.      */
  108.     @Override
  109.     public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
  110.         return new DiscreteProbabilityCollectionSampler<>(items, sampler.withUniformRandomProvider(rng));
  111.     }

  112.     /**
  113.      * Creates the sampler of the enumerated probability distribution.
  114.      *
  115.      * @param rng Generator of uniformly distributed random numbers.
  116.      * @param probabilities Probability associated to each item.
  117.      * @return the sampler
  118.      */
  119.     private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
  120.                                                             double[] probabilities) {
  121.         return GuideTableDiscreteSampler.of(rng, probabilities);
  122.     }

  123.     /**
  124.      * Creates the sampler of the enumerated probability distribution.
  125.      *
  126.      * @param <T> Type of items in the collection.
  127.      * @param rng Generator of uniformly distributed random numbers.
  128.      * @param collection Collection to be sampled.
  129.      * @param probabilities Probability associated to each item.
  130.      * @return the sampler
  131.      * @throws IllegalArgumentException if the number
  132.      * of items in the {@code collection} is not equal to the number of
  133.      * provided {@code probabilities}.
  134.      */
  135.     private static <T> SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
  136.                                                                 List<T> collection,
  137.                                                                 double[] probabilities) {
  138.         if (probabilities.length != collection.size()) {
  139.             throw new IllegalArgumentException("Size mismatch: " +
  140.                                                probabilities.length + " != " +
  141.                                                collection.size());
  142.         }
  143.         return GuideTableDiscreteSampler.of(rng, probabilities);
  144.     }

  145.     // Validation methods exist to raise an exception before invocation of the
  146.     // private constructor; this mitigates Finalizer attacks
  147.     // (see SpotBugs CT_CONSTRUCTOR_THROW).

  148.     /**
  149.      * Extract the items.
  150.      *
  151.      * @param <T> Type of items in the collection.
  152.      * @param collection Collection.
  153.      * @return the items
  154.      * @throws IllegalArgumentException if {@code collection} is empty.
  155.      */
  156.     private static <T> List<T> toList(Map<T, Double> collection) {
  157.         if (collection.isEmpty()) {
  158.             throw new IllegalArgumentException(EMPTY_COLLECTION);
  159.         }
  160.         return new ArrayList<>(collection.keySet());
  161.     }

  162.     /**
  163.      * Extract the probabilities.
  164.      *
  165.      * @param <T> Type of items in the collection.
  166.      * @param collection Collection.
  167.      * @return the probabilities
  168.      */
  169.     private static <T> double[] toProbabilities(Map<T, Double> collection) {
  170.         final int size = collection.size();
  171.         final double[] probabilities = new double[size];
  172.         int count = 0;
  173.         for (final Double e : collection.values()) {
  174.             final double probability = e;
  175.             if (probability < 0 ||
  176.                 Double.isInfinite(probability) ||
  177.                 Double.isNaN(probability)) {
  178.                 throw new IllegalArgumentException("Invalid probability: " +
  179.                                                    probability);
  180.             }
  181.             probabilities[count++] = probability;
  182.         }
  183.         return probabilities;
  184.     }

  185.     /**
  186.      * Create a (shallow) copy of the collection.
  187.      *
  188.      * @param <T> Type of items in the collection.
  189.      * @param collection Collection.
  190.      * @return the copy
  191.      * @throws IllegalArgumentException if {@code collection} is empty.
  192.      */
  193.     private static <T> List<T> copyList(List<T> collection) {
  194.         if (collection.isEmpty()) {
  195.             throw new IllegalArgumentException(EMPTY_COLLECTION);
  196.         }
  197.         return new ArrayList<>(collection);
  198.     }
  199. }