1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.rng.sampling;
19
20 import java.util.List;
21 import java.util.Map;
22 import java.util.ArrayList;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
25 import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
26
27 /**
28 * Sampling from a collection of items with user-defined
29 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
30 * probabilities</a>.
31 * Note that if all unique items are assigned the same probability,
32 * it is much more efficient to use {@link CollectionSampler}.
33 *
34 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
35 *
36 * @param <T> Type of items in the collection.
37 *
38 * @since 1.1
39 */
40 public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> {
41 /** The error message for an empty collection. */
42 private static final String EMPTY_COLLECTION = "Empty collection";
43 /** Collection to be sampled from. */
44 private final List<T> items;
45 /** Sampler for the probabilities. */
46 private final SharedStateDiscreteSampler sampler;
47
48 /**
49 * Creates a sampler.
50 *
51 * @param rng Generator of uniformly distributed random numbers.
52 * @param collection Collection to be sampled, with the probabilities
53 * associated to each of its items.
54 * A (shallow) copy of the items will be stored in the created instance.
55 * The probabilities must be non-negative, but zero values are allowed
56 * and their sum does not have to equal one (input will be normalized
57 * to make the probabilities sum to one).
58 * @throws IllegalArgumentException if {@code collection} is empty, a
59 * probability is negative, infinite or {@code NaN}, or the sum of all
60 * probabilities is not strictly positive.
61 */
62 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
63 Map<T, Double> collection) {
64 this(toList(collection),
65 createSampler(rng, toProbabilities(collection)));
66 }
67
68 /**
69 * Creates a sampler.
70 *
71 * @param rng Generator of uniformly distributed random numbers.
72 * @param collection Collection to be sampled.
73 * A (shallow) copy of the items will be stored in the created instance.
74 * @param probabilities Probability associated to each item of the
75 * {@code collection}.
76 * The probabilities must be non-negative, but zero values are allowed
77 * and their sum does not have to equal one (input will be normalized
78 * to make the probabilities sum to one).
79 * @throws IllegalArgumentException if {@code collection} is empty or
80 * a probability is negative, infinite or {@code NaN}, or if the number
81 * of items in the {@code collection} is not equal to the number of
82 * provided {@code probabilities}.
83 */
84 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
85 List<T> collection,
86 double[] probabilities) {
87 this(copyList(collection),
88 createSampler(rng, collection, probabilities));
89 }
90
91 /**
92 * @param items Collection to be sampled.
93 * @param sampler Sampler for the probabilities.
94 */
95 private DiscreteProbabilityCollectionSampler(List<T> items,
96 SharedStateDiscreteSampler sampler) {
97 this.items = items;
98 this.sampler = sampler;
99 }
100
101 /**
102 * Picks one of the items from the collection passed to the constructor.
103 *
104 * @return a random sample.
105 */
106 @Override
107 public T sample() {
108 return items.get(sampler.sample());
109 }
110
111 /**
112 * {@inheritDoc}
113 *
114 * @since 1.3
115 */
116 @Override
117 public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
118 return new DiscreteProbabilityCollectionSampler<>(items, sampler.withUniformRandomProvider(rng));
119 }
120
121 /**
122 * Creates the sampler of the enumerated probability distribution.
123 *
124 * @param rng Generator of uniformly distributed random numbers.
125 * @param probabilities Probability associated to each item.
126 * @return the sampler
127 */
128 private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
129 double[] probabilities) {
130 return GuideTableDiscreteSampler.of(rng, probabilities);
131 }
132
133 /**
134 * Creates the sampler of the enumerated probability distribution.
135 *
136 * @param <T> Type of items in the collection.
137 * @param rng Generator of uniformly distributed random numbers.
138 * @param collection Collection to be sampled.
139 * @param probabilities Probability associated to each item.
140 * @return the sampler
141 * @throws IllegalArgumentException if the number
142 * of items in the {@code collection} is not equal to the number of
143 * provided {@code probabilities}.
144 */
145 private static <T> SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
146 List<T> collection,
147 double[] probabilities) {
148 if (probabilities.length != collection.size()) {
149 throw new IllegalArgumentException("Size mismatch: " +
150 probabilities.length + " != " +
151 collection.size());
152 }
153 return GuideTableDiscreteSampler.of(rng, probabilities);
154 }
155
156 // Validation methods exist to raise an exception before invocation of the
157 // private constructor; this mitigates Finalizer attacks
158 // (see SpotBugs CT_CONSTRUCTOR_THROW).
159
160 /**
161 * Extract the items.
162 *
163 * @param <T> Type of items in the collection.
164 * @param collection Collection.
165 * @return the items
166 * @throws IllegalArgumentException if {@code collection} is empty.
167 */
168 private static <T> List<T> toList(Map<T, Double> collection) {
169 if (collection.isEmpty()) {
170 throw new IllegalArgumentException(EMPTY_COLLECTION);
171 }
172 return new ArrayList<>(collection.keySet());
173 }
174
175 /**
176 * Extract the probabilities.
177 *
178 * @param <T> Type of items in the collection.
179 * @param collection Collection.
180 * @return the probabilities
181 */
182 private static <T> double[] toProbabilities(Map<T, Double> collection) {
183 final int size = collection.size();
184 final double[] probabilities = new double[size];
185 int count = 0;
186 for (final Double e : collection.values()) {
187 final double probability = e;
188 if (probability < 0 ||
189 Double.isInfinite(probability) ||
190 Double.isNaN(probability)) {
191 throw new IllegalArgumentException("Invalid probability: " +
192 probability);
193 }
194 probabilities[count++] = probability;
195 }
196 return probabilities;
197 }
198
199 /**
200 * Create a (shallow) copy of the collection.
201 *
202 * @param <T> Type of items in the collection.
203 * @param collection Collection.
204 * @return the copy
205 * @throws IllegalArgumentException if {@code collection} is empty.
206 */
207 private static <T> List<T> copyList(List<T> collection) {
208 if (collection.isEmpty()) {
209 throw new IllegalArgumentException(EMPTY_COLLECTION);
210 }
211 return new ArrayList<>(collection);
212 }
213 }