001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020 021/** 022 * Compute a sample from {@code n} values each with an associated probability. If all unique items 023 * are assigned the same probability it is more efficient to use the {@link DiscreteUniformSampler}. 024 * 025 * <p>The cumulative probability distribution is searched using a guide table to set an 026 * initial start point. This implementation is based on:</p> 027 * 028 * <blockquote> 029 * Devroye, Luc (1986). Non-Uniform Random Variate Generation. 030 * New York: Springer-Verlag. Chapter 3.2.4 "The method of guide tables" p. 96. 031 * </blockquote> 032 * 033 * <p>The size of the guide table can be controlled using a parameter. A larger guide table 034 * will improve performance at the cost of storage space.</p> 035 * 036 * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p> 037 * 038 * @see <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 039 * Discrete probability distribution (Wikipedia)</a> 040 * @since 1.3 041 */ 042public final class GuideTableDiscreteSampler 043 implements SharedStateDiscreteSampler { 044 /** The default value for {@code alpha}. */ 045 private static final double DEFAULT_ALPHA = 1.0; 046 /** Underlying source of randomness. */ 047 private final UniformRandomProvider rng; 048 /** 049 * The cumulative probability table ({@code f(x)}). 050 */ 051 private final double[] cumulativeProbabilities; 052 /** 053 * The inverse cumulative probability guide table. This is a guide map between the cumulative 054 * probability (f(x)) and the value x. It is used to set the initial point for search 055 * of the cumulative probability table. 056 * 057 * <p>The index in the map is obtained using {@code p * map.length} where {@code p} is the 058 * known cumulative probability {@code f(x)} or a uniform random deviate {@code u}. The value 059 * stored at the index is value {@code x+1} when {@code p = f(x)} such that it is the 060 * exclusive upper bound on the sample value {@code x} for searching the cumulative probability 061 * table {@code f(x)}. The search of the cumulative probability is towards zero.</p> 062 */ 063 private final int[] guideTable; 064 065 /** 066 * @param rng Generator of uniformly distributed random numbers. 067 * @param cumulativeProbabilities The cumulative probability table ({@code f(x)}). 068 * @param guideTable The inverse cumulative probability guide table. 069 */ 070 private GuideTableDiscreteSampler(UniformRandomProvider rng, 071 double[] cumulativeProbabilities, 072 int[] guideTable) { 073 this.rng = rng; 074 this.cumulativeProbabilities = cumulativeProbabilities; 075 this.guideTable = guideTable; 076 } 077 078 /** {@inheritDoc} */ 079 @Override 080 public int sample() { 081 // Compute a probability 082 final double u = rng.nextDouble(); 083 084 // Initialise the search using the guide table to find an initial guess. 085 // The table provides an upper bound on the sample (x+1) for a known 086 // cumulative probability (f(x)). 087 int x = guideTable[getGuideTableIndex(u, guideTable.length)]; 088 // Search down. 089 // In the edge case where u is 1.0 then 'x' will be 1 outside the range of the 090 // cumulative probability table and this will decrement to a valid range. 091 // In the case where 'u' is mapped to the same guide table index as a lower 092 // cumulative probability f(x) (due to rounding down) then this will not decrement 093 // and return the exclusive upper bound (x+1). 094 while (x != 0 && u <= cumulativeProbabilities[x - 1]) { 095 x--; 096 } 097 return x; 098 } 099 100 /** {@inheritDoc} */ 101 @Override 102 public String toString() { 103 return "Guide table deviate [" + rng.toString() + "]"; 104 } 105 106 /** {@inheritDoc} */ 107 @Override 108 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { 109 return new GuideTableDiscreteSampler(rng, cumulativeProbabilities, guideTable); 110 } 111 112 /** 113 * Create a new sampler for an enumerated distribution using the given {@code probabilities}. 114 * The samples corresponding to each probability are assumed to be a natural sequence 115 * starting at zero. 116 * 117 * <p>The size of the guide table is {@code probabilities.length}.</p> 118 * 119 * @param rng Generator of uniformly distributed random numbers. 120 * @param probabilities The probabilities. 121 * @return the sampler 122 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 123 * probability is negative, infinite or {@code NaN}, or the sum of all 124 * probabilities is not strictly positive. 125 */ 126 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 127 double[] probabilities) { 128 return of(rng, probabilities, DEFAULT_ALPHA); 129 } 130 131 /** 132 * Create a new sampler for an enumerated distribution using the given {@code probabilities}. 133 * The samples corresponding to each probability are assumed to be a natural sequence 134 * starting at zero. 135 * 136 * <p>The size of the guide table is {@code alpha * probabilities.length}.</p> 137 * 138 * @param rng Generator of uniformly distributed random numbers. 139 * @param probabilities The probabilities. 140 * @param alpha The alpha factor used to set the guide table size. 141 * @return the sampler 142 * @throws IllegalArgumentException if {@code probabilities} is null or empty, a 143 * probability is negative, infinite or {@code NaN}, the sum of all 144 * probabilities is not strictly positive, or {@code alpha} is not strictly positive. 145 */ 146 public static SharedStateDiscreteSampler of(UniformRandomProvider rng, 147 double[] probabilities, 148 double alpha) { 149 validateParameters(probabilities, alpha); 150 151 final int size = probabilities.length; 152 final double[] cumulativeProbabilities = new double[size]; 153 154 double sumProb = 0; 155 int count = 0; 156 for (final double prob : probabilities) { 157 InternalUtils.validateProbability(prob); 158 159 // Compute and store cumulative probability. 160 sumProb += prob; 161 cumulativeProbabilities[count++] = sumProb; 162 } 163 164 if (Double.isInfinite(sumProb) || sumProb <= 0) { 165 throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb); 166 } 167 168 // Note: The guide table is at least length 1. Compute the size avoiding overflow 169 // in case (alpha * size) is too large. 170 final int guideTableSize = (int) Math.ceil(alpha * size); 171 final int[] guideTable = new int[Math.max(guideTableSize, guideTableSize + 1)]; 172 173 // Compute and store cumulative probability. 174 for (int x = 0; x < size; x++) { 175 final double norm = cumulativeProbabilities[x] / sumProb; 176 cumulativeProbabilities[x] = (norm < 1) ? norm : 1.0; 177 178 // Set the guide table value as an exclusive upper bound (x + 1) 179 final int index = getGuideTableIndex(cumulativeProbabilities[x], guideTable.length); 180 guideTable[index] = x + 1; 181 } 182 183 // Edge case for round-off 184 cumulativeProbabilities[size - 1] = 1.0; 185 // The final guide table entry is (maximum value of x + 1) 186 guideTable[guideTable.length - 1] = size; 187 188 // The first non-zero value in the guide table is from f(x=0). 189 // Any probabilities mapped below this must be sample x=0 so the 190 // table may initially be filled with zeros. 191 192 // Fill missing values in the guide table. 193 for (int i = 1; i < guideTable.length; i++) { 194 guideTable[i] = Math.max(guideTable[i - 1], guideTable[i]); 195 } 196 197 return new GuideTableDiscreteSampler(rng, cumulativeProbabilities, guideTable); 198 } 199 200 /** 201 * Validate the parameters. 202 * 203 * @param probabilities The probabilities. 204 * @param alpha The alpha factor used to set the guide table size. 205 * @throws IllegalArgumentException if {@code probabilities} is null or empty, or 206 * {@code alpha} is not strictly positive. 207 */ 208 private static void validateParameters(double[] probabilities, double alpha) { 209 if (probabilities == null || probabilities.length == 0) { 210 throw new IllegalArgumentException("Probabilities must not be empty."); 211 } 212 if (alpha <= 0) { 213 throw new IllegalArgumentException("Alpha must be strictly positive."); 214 } 215 } 216 217 /** 218 * Gets the guide table index for the probability. This is obtained using 219 * {@code p * (tableLength - 1)} so is inside the length of the table. 220 * 221 * @param p Cumulative probability. 222 * @param tableLength Table length. 223 * @return the guide table index. 224 */ 225 private static int getGuideTableIndex(double p, int tableLength) { 226 // Note: This is only ever called when p is in the range of the cumulative 227 // probability table. So assume 0 <= p <= 1. 228 return (int) (p * (tableLength - 1)); 229 } 230}