001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.rng.sampling; 019 020import java.util.List; 021import java.util.Map; 022import java.util.HashMap; 023import java.util.ArrayList; 024import java.util.Arrays; 025 026import org.apache.commons.rng.UniformRandomProvider; 027 028/** 029 * Sampling from a collection of items with user-defined 030 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 031 * probabilities</a>. 032 * Note that if all unique items are assigned the same probability, 033 * it is much more efficient to use {@link CollectionSampler}. 034 * 035 * @param <T> Type of items in the collection. 036 * 037 * @since 1.1 038 */ 039public class DiscreteProbabilityCollectionSampler<T> { 040 /** Collection to be sampled from. */ 041 private final List<T> items; 042 /** RNG. */ 043 private final UniformRandomProvider rng; 044 /** Cumulative probabilities. */ 045 private final double[] cumulativeProbabilities; 046 047 /** 048 * Creates a sampler. 049 * 050 * @param rng Generator of uniformly distributed random numbers. 051 * @param collection Collection to be sampled, with the probabilities 052 * associated to each of its items. 053 * A (shallow) copy of the items will be stored in the created instance. 054 * The probabilities must be non-negative, but zero values are allowed 055 * and their sum does not have to equal one (input will be normalized 056 * to make the probabilities sum to one). 057 * @throws IllegalArgumentException if {@code collection} is empty, a 058 * probability is negative, infinite or {@code NaN}, or the sum of all 059 * probabilities is not strictly positive. 060 */ 061 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 062 Map<T, Double> collection) { 063 if (collection.isEmpty()) { 064 throw new IllegalArgumentException("Empty collection"); 065 } 066 067 this.rng = rng; 068 final int size = collection.size(); 069 items = new ArrayList<T>(size); 070 cumulativeProbabilities = new double[size]; 071 072 double sumProb = 0; 073 int count = 0; 074 for (Map.Entry<T, Double> e : collection.entrySet()) { 075 items.add(e.getKey()); 076 077 final double prob = e.getValue(); 078 if (prob < 0 || 079 Double.isInfinite(prob) || 080 Double.isNaN(prob)) { 081 throw new IllegalArgumentException("Invalid probability: " + 082 prob); 083 } 084 085 // Temporarily store probability. 086 cumulativeProbabilities[count++] = prob; 087 sumProb += prob; 088 } 089 090 if (!(sumProb > 0)) { 091 throw new IllegalArgumentException("Invalid sum of probabilities"); 092 } 093 094 // Compute and store cumulative probability. 095 for (int i = 0; i < size; i++) { 096 cumulativeProbabilities[i] /= sumProb; 097 if (i > 0) { 098 cumulativeProbabilities[i] += cumulativeProbabilities[i - 1]; 099 } 100 } 101 } 102 103 /** 104 * Creates a sampler. 105 * 106 * @param rng Generator of uniformly distributed random numbers. 107 * @param collection Collection to be sampled. 108 * A (shallow) copy of the items will be stored in the created instance. 109 * @param probabilities Probability associated to each item of the 110 * {@code collection}. 111 * The probabilities must be non-negative, but zero values are allowed 112 * and their sum does not have to equal one (input will be normalized 113 * to make the probabilities sum to one). 114 * @throws IllegalArgumentException if {@code collection} is empty or 115 * a probability is negative, infinite or {@code NaN}, or if the number 116 * of items in the {@code collection} is not equal to the number of 117 * provided {@code probabilities}. 118 */ 119 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 120 List<T> collection, 121 double[] probabilities) { 122 this(rng, consolidate(collection, probabilities)); 123 } 124 125 /** 126 * Picks one of the items from the collection passed to the constructor. 127 * 128 * @return a random sample. 129 */ 130 public T sample() { 131 final double rand = rng.nextDouble(); 132 133 int index = Arrays.binarySearch(cumulativeProbabilities, rand); 134 if (index < 0) { 135 index = -index - 1; 136 } 137 138 if (index >= 0 && 139 index < cumulativeProbabilities.length && 140 rand < cumulativeProbabilities[index]) { 141 return items.get(index); 142 } 143 144 // This should never happen, but it ensures we will return a correct 145 // object in case there is some floating point inequality problem 146 // wrt the cumulative probabilities. 147 return items.get(items.size() - 1); 148 } 149 150 /** 151 * @param collection Collection to be sampled. 152 * @param probabilities Probability associated to each item of the 153 * {@code collection}. 154 * @return a consolidated map (where probabilities of equal items 155 * have been summed). 156 * @throws IllegalArgumentException if the number of items in the 157 * {@code collection} is not equal to the number of provided 158 * {@code probabilities}. 159 * @param <T> Type of items in the collection. 160 */ 161 private static <T> Map<T, Double> consolidate(List<T> collection, 162 double[] probabilities) { 163 final int len = probabilities.length; 164 if (len != collection.size()) { 165 throw new IllegalArgumentException("Size mismatch: " + 166 len + " != " + 167 collection.size()); 168 } 169 170 final Map<T, Double> map = new HashMap<T, Double>(); 171 for (int i = 0; i < len; i++) { 172 final T item = collection.get(i); 173 final Double prob = probabilities[i]; 174 175 Double currentProb = map.get(item); 176 if (currentProb == null) { 177 currentProb = 0d; 178 } 179 180 map.put(item, currentProb + prob); 181 } 182 183 return map; 184 } 185}