1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package org.apache.commons.rng.sampling.distribution; 18 19 import org.apache.commons.rng.UniformRandomProvider; 20 21 /** 22 * Compute a sample from {@code n} values each with an associated probability. If all unique items 23 * are assigned the same probability it is more efficient to use the {@link DiscreteUniformSampler}. 24 * 25 * <p>The cumulative probability distribution is searched using a guide table to set an 26 * initial start point. This implementation is based on:</p> 27 * 28 * <blockquote> 29 * Devroye, Luc (1986). Non-Uniform Random Variate Generation. 30 * New York: Springer-Verlag. Chapter 3.2.4 "The method of guide tables" p. 96. 31 * </blockquote> 32 * 33 * <p>The size of the guide table can be controlled using a parameter. A larger guide table 34 * will improve performance at the cost of storage space.</p> 35 * 36 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p> 37 * 38 * @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 39 * Discrete probability distribution (Wikipedia)</a> 40 * @since 1.3 41 */ 42 public final class GuideTableDiscreteSampler 43 implements SharedStateDiscreteSampler { 44 /** The default value for {@code alpha}. */ 45 private static final double DEFAULT_ALPHA = 1.0; 46 /** Underlying source of randomness. */ 47 private final UniformRandomProvider rng; 48 /** 49 * The cumulative probability table ({@code f(x)}). 50 */ 51 private final double[] cumulativeProbabilities; 52 /** 53 * The inverse cumulative probability guide table. This is a guide map between the cumulative 54 * probability (f(x)) and the value x. It is used to set the initial point for search 55 * of the cumulative probability table. 56 * 57 * <p>The index in the map is obtained using {@code p * map.length} where {@code p} is the 58 * known cumulative probability {@code f(x)} or a uniform random deviate {@code u}. The value 59 * stored at the index is value {@code x+1} when {@code p = f(x)} such that it is the 60 * exclusive upper bound on the sample value {@code x} for searching the cumulative probability 61 * table {@code f(x)}. The search of the cumulative probability is towards zero.</p> 62 */ 63 private final int[] guideTable; 64 65 /** 66 * @param rng Generator of uniformly distributed random numbers. 67 * @param cumulativeProbabilities The cumulative probability table ({@code f(x)}). 68 * @param guideTable The inverse cumulative probability guide table. 69 */ 70 private GuideTableDiscreteSampler(UniformRandomProvider rng, 71 double[] cumulativeProbabilities, 72 int[] guideTable) { 73 this.rng = rng; 74 this.cumulativeProbabilities = cumulativeProbabilities; 75 this.guideTable = guideTable; 76 } 77 78 /** {@inheritDoc} */ 79 @Override 80 public int sample() { 81 // Compute a probability 82 final double u = rng.nextDouble(); 83 84 // Initialise the search using the guide table to find an initial guess. 85 // The table provides an upper bound on the sample (x+1) for a known 86 // cumulative probability (f(x)). 87 int x = guideTable[getGuideTableIndex(u, guideTable.length)]; 88 // Search down. 89 // In the edge case where u is 1.0 then 'x' will be 1 outside the range of the 90 // cumulative probability table and this will decrement to a valid range. 91 // In the case where 'u' is mapped to the same guide table index as a lower 92 // cumulative probability f(x) (due to rounding down) then this will not decrement 93 // and return the exclusive upper bound (x+1). 94 while (x != 0 && u <= cumulativeProbabilities[x - 1]) { 95 x--; 96 } 97 return x; 98 } 99 100 /** {@inheritDoc} */ 101 @Override 102 public String toString() { 103 return "Guide table deviate [" + rng.toString() + "]"; 104 } 105 106 /** {@inheritDoc} */ 107 @Override 108 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 109 return new GuideTableDiscreteSampler(rng, cumulativeProbabilities, guideTable); 110 } 111 112 /** 113 * Create a new sampler for an enumerated distribution using the given {@code probabilities}. 114 * The samples corresponding to each probability are assumed to be a natural sequence 115 * starting at zero. 116 * 117 * <p>The size of the guide table is {@code probabilities.length}.</p> 118 * 119 * @param rng Generator of uniformly distributed random numbers. 120 * @param probabilities The probabilities. 121 * @return the sampler 122 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 123 * probability is negative, infinite or {@code NaN}, or the sum of all 124 * probabilities is not strictly positive. 125 */ 126 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 127 double[] probabilities) { 128 return of(rng, probabilities, DEFAULT_ALPHA); 129 } 130 131 /** 132 * Create a new sampler for an enumerated distribution using the given {@code probabilities}. 133 * The samples corresponding to each probability are assumed to be a natural sequence 134 * starting at zero. 135 * 136 * <p>The size of the guide table is {@code alpha * probabilities.length}.</p> 137 * 138 * @param rng Generator of uniformly distributed random numbers. 139 * @param probabilities The probabilities. 140 * @param alpha The alpha factor used to set the guide table size. 141 * @return the sampler 142 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 143 * probability is negative, infinite or {@code NaN}, the sum of all 144 * probabilities is not strictly positive, or {@code alpha} is not strictly positive. 145 */ 146 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 147 double[] probabilities, 148 double alpha) { 149 validateParameters(probabilities, alpha); 150 151 final int size = probabilities.length; 152 final double[] cumulativeProbabilities = new double[size]; 153 154 double sumProb = 0; 155 int count = 0; 156 for (final double prob : probabilities) { 157 InternalUtils.validateProbability(prob); 158 159 // Compute and store cumulative probability. 160 sumProb += prob; 161 cumulativeProbabilities[count++] = sumProb; 162 } 163 164 if (Double.isInfinite(sumProb) || sumProb <= 0) { 165 throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb); 166 } 167 168 // Note: The guide table is at least length 1. Compute the size avoiding overflow 169 // in case (alpha * size) is too large. 170 final int guideTableSize = (int) Math.ceil(alpha * size); 171 final int[] guideTable = new int[Math.max(guideTableSize, guideTableSize + 1)]; 172 173 // Compute and store cumulative probability. 174 for (int x = 0; x < size; x++) { 175 final double norm = cumulativeProbabilities[x] / sumProb; 176 cumulativeProbabilities[x] = (norm < 1) ? norm : 1.0; 177 178 // Set the guide table value as an exclusive upper bound (x + 1) 179 final int index = getGuideTableIndex(cumulativeProbabilities[x], guideTable.length); 180 guideTable[index] = x + 1; 181 } 182 183 // Edge case for round-off 184 cumulativeProbabilities[size - 1] = 1.0; 185 // The final guide table entry is (maximum value of x + 1) 186 guideTable[guideTable.length - 1] = size; 187 188 // The first non-zero value in the guide table is from f(x=0). 189 // Any probabilities mapped below this must be sample x=0 so the 190 // table may initially be filled with zeros. 191 192 // Fill missing values in the guide table. 193 for (int i = 1; i < guideTable.length; i++) { 194 guideTable[i] = Math.max(guideTable[i - 1], guideTable[i]); 195 } 196 197 return new GuideTableDiscreteSampler(rng, cumulativeProbabilities, guideTable); 198 } 199 200 /** 201 * Validate the parameters. 202 * 203 * @param probabilities The probabilities. 204 * @param alpha The alpha factor used to set the guide table size. 205 * @throws IllegalArgumentException if {@code probabilities} is null or empty, or 206 * {@code alpha} is not strictly positive. 207 */ 208 private static void validateParameters(double[] probabilities, double alpha) { 209 if (probabilities == null || probabilities.length == 0) { 210 throw new IllegalArgumentException("Probabilities must not be empty."); 211 } 212 if (alpha <= 0) { 213 throw new IllegalArgumentException("Alpha must be strictly positive."); 214 } 215 } 216 217 /** 218 * Gets the guide table index for the probability. This is obtained using 219 * {@code p * (tableLength - 1)} so is inside the length of the table. 220 * 221 * @param p Cumulative probability. 222 * @param tableLength Table length. 223 * @return the guide table index. 224 */ 225 private static int getGuideTableIndex(double p, int tableLength) { 226 // Note: This is only ever called when p is in the range of the cumulative 227 // probability table. So assume 0 <= p <= 1. 228 return (int) (p * (tableLength - 1)); 229 } 230 }