001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.rng.sampling;
019
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
025
026import org.apache.commons.rng.UniformRandomProvider;
027
028/**
029 * Sampling from a collection of items with user-defined
030 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
031 * probabilities</a>.
032 * Note that if all unique items are assigned the same probability,
033 * it is much more efficient to use {@link CollectionSampler}.
034 *
035 * @param <T> Type of items in the collection.
036 *
037 * @since 1.1
038 */
039public class DiscreteProbabilityCollectionSampler<T> {
040    /** Collection to be sampled from. */
041    private final List<T> items;
042    /** RNG. */
043    private final UniformRandomProvider rng;
044    /** Cumulative probabilities. */
045    private final double[] cumulativeProbabilities;
046
047    /**
048     * Creates a sampler.
049     *
050     * @param rng Generator of uniformly distributed random numbers.
051     * @param collection Collection to be sampled, with the probabilities
052     * associated to each of its items.
053     * A (shallow) copy of the items will be stored in the created instance.
054     * The probabilities must be non-negative, but zero values are allowed
055     * and their sum does not have to equal one (input will be normalized
056     * to make the probabilities sum to one).
057     * @throws IllegalArgumentException if {@code collection} is empty, a
058     * probability is negative, infinite or {@code NaN}, or the sum of all
059     * probabilities is not strictly positive.
060     */
061    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
062                                                Map<T, Double> collection) {
063        if (collection.isEmpty()) {
064            throw new IllegalArgumentException("Empty collection");
065        }
066
067        this.rng = rng;
068        final int size = collection.size();
069        items = new ArrayList<T>(size);
070        cumulativeProbabilities = new double[size];
071
072        double sumProb = 0;
073        int count = 0;
074        for (Map.Entry<T, Double> e : collection.entrySet()) {
075            items.add(e.getKey());
076
077            final double prob = e.getValue();
078            if (prob < 0 ||
079                Double.isInfinite(prob) ||
080                Double.isNaN(prob)) {
081                throw new IllegalArgumentException("Invalid probability: " +
082                                                   prob);
083            }
084
085            // Temporarily store probability.
086            cumulativeProbabilities[count++] = prob;
087            sumProb += prob;
088        }
089
090        if (!(sumProb > 0)) {
091            throw new IllegalArgumentException("Invalid sum of probabilities");
092        }
093
094        // Compute and store cumulative probability.
095        for (int i = 0; i < size; i++) {
096            cumulativeProbabilities[i] /= sumProb;
097            if (i > 0) {
098                cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
099            }
100        }
101    }
102
103    /**
104     * Creates a sampler.
105     *
106     * @param rng Generator of uniformly distributed random numbers.
107     * @param collection Collection to be sampled.
108     * A (shallow) copy of the items will be stored in the created instance.
109     * @param probabilities Probability associated to each item of the
110     * {@code collection}.
111     * The probabilities must be non-negative, but zero values are allowed
112     * and their sum does not have to equal one (input will be normalized
113     * to make the probabilities sum to one).
114     * @throws IllegalArgumentException if {@code collection} is empty or
115     * a probability is negative, infinite or {@code NaN}, or if the number
116     * of items in the {@code collection} is not equal to the number of
117     * provided {@code probabilities}.
118     */
119    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
120                                                List<T> collection,
121                                                double[] probabilities) {
122        this(rng, consolidate(collection, probabilities));
123    }
124
125    /**
126     * Picks one of the items from the collection passed to the constructor.
127     *
128     * @return a random sample.
129     */
130    public T sample() {
131        final double rand = rng.nextDouble();
132
133        int index = Arrays.binarySearch(cumulativeProbabilities, rand);
134        if (index < 0) {
135            index = -index - 1;
136        }
137
138        if (index >= 0 &&
139            index < cumulativeProbabilities.length &&
140            rand < cumulativeProbabilities[index]) {
141            return items.get(index);
142        }
143
144        // This should never happen, but it ensures we will return a correct
145        // object in case there is some floating point inequality problem
146        // wrt the cumulative probabilities.
147        return items.get(items.size() - 1);
148    }
149
150    /**
151     * @param collection Collection to be sampled.
152     * @param probabilities Probability associated to each item of the
153     * {@code collection}.
154     * @return a consolidated map (where probabilities of equal items
155     * have been summed).
156     * @throws IllegalArgumentException if the number of items in the
157     * {@code collection} is not equal to the number of provided
158     * {@code probabilities}.
159     * @param <T> Type of items in the collection.
160     */
161    private static <T> Map<T, Double> consolidate(List<T> collection,
162                                                  double[] probabilities) {
163        final int len = probabilities.length;
164        if (len != collection.size()) {
165            throw new IllegalArgumentException("Size mismatch: " +
166                                               len + " != " +
167                                               collection.size());
168        }
169
170        final Map<T, Double> map = new HashMap<T, Double>();
171        for (int i = 0; i < len; i++) {
172            final T item = collection.get(i);
173            final Double prob = probabilities[i];
174
175            Double currentProb = map.get(item);
176            if (currentProb == null) {
177                currentProb = 0d;
178            }
179
180            map.put(item, currentProb + prob);
181        }
182
183        return map;
184    }
185}