001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.rng.sampling;
019
020import java.util.List;
021import java.util.Map;
022import java.util.ArrayList;
023
024import org.apache.commons.rng.UniformRandomProvider;
025import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
026import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
027
028/**
029 * Sampling from a collection of items with user-defined
030 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
031 * probabilities</a>.
032 * Note that if all unique items are assigned the same probability,
033 * it is much more efficient to use {@link CollectionSampler}.
034 *
035 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
036 *
037 * @param <T> Type of items in the collection.
038 *
039 * @since 1.1
040 */
041public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> {
042    /** The error message for an empty collection. */
043    private static final String EMPTY_COLLECTION = "Empty collection";
044    /** Collection to be sampled from. */
045    private final List<T> items;
046    /** Sampler for the probabilities. */
047    private final SharedStateDiscreteSampler sampler;
048
049    /**
050     * Creates a sampler.
051     *
052     * @param rng Generator of uniformly distributed random numbers.
053     * @param collection Collection to be sampled, with the probabilities
054     * associated to each of its items.
055     * A (shallow) copy of the items will be stored in the created instance.
056     * The probabilities must be non-negative, but zero values are allowed
057     * and their sum does not have to equal one (input will be normalized
058     * to make the probabilities sum to one).
059     * @throws IllegalArgumentException if {@code collection} is empty, a
060     * probability is negative, infinite or {@code NaN}, or the sum of all
061     * probabilities is not strictly positive.
062     */
063    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
064                                                Map<T, Double> collection) {
065        if (collection.isEmpty()) {
066            throw new IllegalArgumentException(EMPTY_COLLECTION);
067        }
068
069        // Extract the items and probabilities
070        final int size = collection.size();
071        items = new ArrayList<>(size);
072        final double[] probabilities = new double[size];
073
074        int count = 0;
075        for (final Map.Entry<T, Double> e : collection.entrySet()) {
076            items.add(e.getKey());
077            probabilities[count++] = e.getValue();
078        }
079
080        // Delegate sampling
081        sampler = createSampler(rng, probabilities);
082    }
083
084    /**
085     * Creates a sampler.
086     *
087     * @param rng Generator of uniformly distributed random numbers.
088     * @param collection Collection to be sampled.
089     * A (shallow) copy of the items will be stored in the created instance.
090     * @param probabilities Probability associated to each item of the
091     * {@code collection}.
092     * The probabilities must be non-negative, but zero values are allowed
093     * and their sum does not have to equal one (input will be normalized
094     * to make the probabilities sum to one).
095     * @throws IllegalArgumentException if {@code collection} is empty or
096     * a probability is negative, infinite or {@code NaN}, or if the number
097     * of items in the {@code collection} is not equal to the number of
098     * provided {@code probabilities}.
099     */
100    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
101                                                List<T> collection,
102                                                double[] probabilities) {
103        if (collection.isEmpty()) {
104            throw new IllegalArgumentException(EMPTY_COLLECTION);
105        }
106        final int len = probabilities.length;
107        if (len != collection.size()) {
108            throw new IllegalArgumentException("Size mismatch: " +
109                                               len + " != " +
110                                               collection.size());
111        }
112        // Shallow copy the list
113        items = new ArrayList<>(collection);
114        // Delegate sampling
115        sampler = createSampler(rng, probabilities);
116    }
117
118    /**
119     * @param rng Generator of uniformly distributed random numbers.
120     * @param source Source to copy.
121     */
122    private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
123                                                 DiscreteProbabilityCollectionSampler<T> source) {
124        this.items = source.items;
125        this.sampler = source.sampler.withUniformRandomProvider(rng);
126    }
127
128    /**
129     * Picks one of the items from the collection passed to the constructor.
130     *
131     * @return a random sample.
132     */
133    @Override
134    public T sample() {
135        return items.get(sampler.sample());
136    }
137
138    /**
139     * {@inheritDoc}
140     *
141     * @since 1.3
142     */
143    @Override
144    public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
145        return new DiscreteProbabilityCollectionSampler<>(rng, this);
146    }
147
148    /**
149     * Creates the sampler of the enumerated probability distribution.
150     *
151     * @param rng Generator of uniformly distributed random numbers.
152     * @param probabilities Probability associated to each item.
153     * @return the sampler
154     */
155    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
156                                                            double[] probabilities) {
157        return GuideTableDiscreteSampler.of(rng, probabilities);
158    }
159}