001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020import org.apache.commons.rng.sampling.SharedStateObjectSampler; 021 022/** 023 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet 024 * distribution</a>. 025 * 026 * <p>Sampling uses:</p> 027 * 028 * <ul> 029 * <li>{@link UniformRandomProvider#nextLong()} 030 * <li>{@link UniformRandomProvider#nextDouble()} 031 * </ul> 032 * 033 * @since 1.4 034 */ 035public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> { 036 /** The minimum number of categories. */ 037 private static final int MIN_CATGEORIES = 2; 038 039 /** RNG (used for the toString() method). */ 040 private final UniformRandomProvider rng; 041 042 /** 043 * Sample from a Dirichlet distribution with different concentration parameters 044 * for each category. 045 */ 046 private static final class GeneralDirichletSampler extends DirichletSampler { 047 /** Samplers for each category. */ 048 private final SharedStateContinuousSampler[] samplers; 049 050 /** 051 * @param rng Generator of uniformly distributed random numbers. 052 * @param samplers Samplers for each category. 053 */ 054 GeneralDirichletSampler(UniformRandomProvider rng, 055 SharedStateContinuousSampler[] samplers) { 056 super(rng); 057 // Array is stored directly as it is generated within the DirichletSampler class 058 this.samplers = samplers; 059 } 060 061 @Override 062 protected int getK() { 063 return samplers.length; 064 } 065 066 @Override 067 protected double nextGamma(int i) { 068 return samplers[i].sample(); 069 } 070 071 @Override 072 public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) { 073 final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length]; 074 for (int i = 0; i < newSamplers.length; i++) { 075 newSamplers[i] = samplers[i].withUniformRandomProvider(rng); 076 } 077 return new GeneralDirichletSampler(rng, newSamplers); 078 } 079 } 080 081 /** 082 * Sample from a symmetric Dirichlet distribution with the same concentration parameter 083 * for each category. 084 */ 085 private static final class SymmetricDirichletSampler extends DirichletSampler { 086 /** Number of categories. */ 087 private final int k; 088 /** Sampler for the categories. */ 089 private final SharedStateContinuousSampler sampler; 090 091 /** 092 * @param rng Generator of uniformly distributed random numbers. 093 * @param k Number of categories. 094 * @param sampler Sampler for the categories. 095 */ 096 SymmetricDirichletSampler(UniformRandomProvider rng, 097 int k, 098 SharedStateContinuousSampler sampler) { 099 super(rng); 100 this.k = k; 101 this.sampler = sampler; 102 } 103 104 @Override 105 protected int getK() { 106 return k; 107 } 108 109 @Override 110 protected double nextGamma(int i) { 111 return sampler.sample(); 112 } 113 114 @Override 115 public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) { 116 return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng)); 117 } 118 } 119 120 /** 121 * @param rng Generator of uniformly distributed random numbers. 122 */ 123 private DirichletSampler(UniformRandomProvider rng) { 124 this.rng = rng; 125 } 126 127 /** {@inheritDoc} */ 128 @Override 129 public String toString() { 130 return "Dirichlet deviate [" + rng.toString() + "]"; 131 } 132 133 @Override 134 public double[] sample() { 135 // Create Gamma(alpha_i, 1) deviates for all alpha 136 final double[] y = new double[getK()]; 137 double norm = 0; 138 for (int i = 0; i < y.length; i++) { 139 final double yi = nextGamma(i); 140 norm += yi; 141 y[i] = yi; 142 } 143 // Normalize by dividing by the sum of the samples 144 norm = 1.0 / norm; 145 // Detect an invalid normalization, e.g. cases of all zero samples 146 if (!isNonZeroPositiveFinite(norm)) { 147 // Sample again using recursion. 148 // A stack overflow due to a broken RNG will eventually occur 149 // rather than the alternative which is an infinite loop. 150 return sample(); 151 } 152 // Normalise 153 for (int i = 0; i < y.length; i++) { 154 y[i] *= norm; 155 } 156 return y; 157 } 158 159 /** 160 * Gets the number of categories. 161 * 162 * @return k 163 */ 164 protected abstract int getK(); 165 166 /** 167 * Create a gamma sample for the given category. 168 * 169 * @param category Category. 170 * @return the sample 171 */ 172 protected abstract double nextGamma(int category); 173 174 /** {@inheritDoc} */ 175 // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]> 176 @Override 177 public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng); 178 179 /** 180 * Creates a new Dirichlet distribution sampler. 181 * 182 * @param rng Generator of uniformly distributed random numbers. 183 * @param alpha Concentration parameters. 184 * @return the sampler 185 * @throws IllegalArgumentException if the number of concentration parameters 186 * is less than 2; or if any concentration parameter is not strictly positive. 187 */ 188 public static DirichletSampler of(UniformRandomProvider rng, 189 double... alpha) { 190 validateNumberOfCategories(alpha.length); 191 final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length]; 192 for (int i = 0; i < samplers.length; i++) { 193 samplers[i] = createSampler(rng, alpha[i]); 194 } 195 return new GeneralDirichletSampler(rng, samplers); 196 } 197 198 /** 199 * Creates a new symmetric Dirichlet distribution sampler using the same concentration 200 * parameter for each category. 201 * 202 * @param rng Generator of uniformly distributed random numbers. 203 * @param k Number of categories. 204 * @param alpha Concentration parameter. 205 * @return the sampler 206 * @throws IllegalArgumentException if the number of categories is 207 * less than 2; or if the concentration parameter is not strictly positive. 208 */ 209 public static DirichletSampler symmetric(UniformRandomProvider rng, 210 int k, 211 double alpha) { 212 validateNumberOfCategories(k); 213 final SharedStateContinuousSampler sampler = createSampler(rng, alpha); 214 return new SymmetricDirichletSampler(rng, k, sampler); 215 } 216 217 /** 218 * Validate the number of categories. 219 * 220 * @param k Number of categories. 221 * @throws IllegalArgumentException if the number of categories is 222 * less than 2. 223 */ 224 private static void validateNumberOfCategories(int k) { 225 if (k < MIN_CATGEORIES) { 226 throw new IllegalArgumentException("Invalid number of categories: " + k); 227 } 228 } 229 230 /** 231 * Creates a gamma sampler for a category with the given concentration parameter. 232 * 233 * @param rng Generator of uniformly distributed random numbers. 234 * @param alpha Concentration parameter. 235 * @return the sampler 236 * @throws IllegalArgumentException if the concentration parameter is not strictly positive. 237 */ 238 private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng, 239 double alpha) { 240 // Negation of logic will detect NaN 241 if (!isNonZeroPositiveFinite(alpha)) { 242 throw new IllegalArgumentException("Invalid concentration: " + alpha); 243 } 244 // Create a Gamma(shape=alpha, scale=1) sampler. 245 if (alpha == 1) { 246 // Special case 247 // Gamma(shape=1, scale=1) == Exponential(mean=1) 248 return ZigguratSampler.Exponential.of(rng); 249 } 250 return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1); 251 } 252 253 /** 254 * Return true if the value is non-zero, positive and finite. 255 * 256 * @param x Value. 257 * @return true if non-zero positive finite 258 */ 259 private static boolean isNonZeroPositiveFinite(double x) { 260 return x > 0 && x < Double.POSITIVE_INFINITY; 261 } 262}