/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.rng.sampling;

import java.util.List;
import java.util.Map;
import java.util.ArrayList;

import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;

/**
 * Sampling from a collection of items with user-defined
 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
 * probabilities</a>.
 * Note that if all unique items are assigned the same probability,
 * it is much more efficient to use {@link CollectionSampler}.
 *
 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
 *
 * @param <T> Type of items in the collection.
 *
 * @since 1.1
 */
public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjectSampler<T> {
    /** The error message for an empty collection. */
    private static final String EMPTY_COLLECTION = "Empty collection";
    /**
     * Collection to be sampled from. Element {@code i} is the item returned
     * when the delegate {@link #sampler} produces index {@code i}.
     */
    private final List<T> items;
    /** Sampler for the probabilities; produces indices into {@link #items}. */
    private final SharedStateDiscreteSampler sampler;

    /**
     * Creates a sampler.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param collection Collection to be sampled, with the probabilities
     * associated to each of its items.
     * A (shallow) copy of the items will be stored in the created instance.
     * The probabilities must be non-negative, but zero values are allowed
     * and their sum does not have to equal one (input will be normalized
     * to make the probabilities sum to one).
     * @throws IllegalArgumentException if {@code collection} is empty, a
     * probability is negative, infinite or {@code NaN}, or the sum of all
     * probabilities is not strictly positive.
     * @throws NullPointerException if any probability (map value) in
     * {@code collection} is {@code null}.
     */
    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                Map<T, Double> collection) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException(EMPTY_COLLECTION);
        }

        // Extract the items and probabilities in a single pass over the entries
        // so that index i of 'items' and index i of 'probabilities' always refer
        // to the same map entry.
        final int size = collection.size();
        items = new ArrayList<>(size);
        final double[] probabilities = new double[size];

        int count = 0;
        for (final Map.Entry<T, Double> e : collection.entrySet()) {
            items.add(e.getKey());
            // Unboxing: a null value throws NullPointerException here.
            probabilities[count++] = e.getValue();
        }

        // Delegate sampling
        sampler = createSampler(rng, probabilities);
    }

    /**
     * Creates a sampler.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param collection Collection to be sampled.
     * A (shallow) copy of the items will be stored in the created instance.
     * @param probabilities Probability associated to each item of the
     * {@code collection}: {@code probabilities[i]} is the weight of
     * {@code collection.get(i)}.
     * The probabilities must be non-negative, but zero values are allowed
     * and their sum does not have to equal one (input will be normalized
     * to make the probabilities sum to one).
     * @throws IllegalArgumentException if {@code collection} is empty,
     * a probability is negative, infinite or {@code NaN}, the sum of all
     * probabilities is not strictly positive, or if the number of items in
     * the {@code collection} is not equal to the number of provided
     * {@code probabilities}.
     */
    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                List<T> collection,
                                                double[] probabilities) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException(EMPTY_COLLECTION);
        }
        final int len = probabilities.length;
        if (len != collection.size()) {
            throw new IllegalArgumentException("Size mismatch: " +
                                               len + " != " +
                                               collection.size());
        }
        // Shallow copy the list so later changes by the caller do not affect sampling
        items = new ArrayList<>(collection);
        // Delegate sampling
        sampler = createSampler(rng, probabilities);
    }

    /**
     * Creates a copy that shares the items of {@code source} but draws
     * random numbers from a different generator.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param source Source to copy.
     */
    private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                 DiscreteProbabilityCollectionSampler<T> source) {
        // The item list is private and never modified after construction,
        // so it can be shared between instances without copying.
        this.items = source.items;
        this.sampler = source.sampler.withUniformRandomProvider(rng);
    }

    /**
     * Picks one of the items from the collection passed to the constructor.
     *
     * @return a random sample.
     */
    @Override
    public T sample() {
        // The delegate returns an index whose distribution follows the
        // probabilities; map it back to the corresponding item.
        return items.get(sampler.sample());
    }

    /**
     * {@inheritDoc}
     *
     * @since 1.3
     */
    @Override
    public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
        return new DiscreteProbabilityCollectionSampler<>(rng, this);
    }

    /**
     * Creates the sampler of the enumerated probability distribution.
     *
     * @param rng Generator of uniformly distributed random numbers.
     * @param probabilities Probability associated to each item.
     * @return the sampler
     */
    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
                                                            double[] probabilities) {
        return GuideTableDiscreteSampler.of(rng, probabilities);
    }
}