001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.commons.rng.sampling.distribution; 018 019import org.apache.commons.rng.UniformRandomProvider; 020import org.apache.commons.rng.sampling.SharedStateObjectSampler; 021 022/** 023 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet 024 * distribution</a>. 025 * 026 * <p>Sampling uses:</p> 027 * 028 * <ul> 029 * <li>{@link UniformRandomProvider#nextLong()} 030 * <li>{@link UniformRandomProvider#nextDouble()} 031 * </ul> 032 * 033 * @since 1.4 034 */ 035public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> { 036 /** The minimum number of categories. */ 037 private static final int MIN_CATGEORIES = 2; 038 039 /** RNG (used for the toString() method). */ 040 private final UniformRandomProvider rng; 041 042 /** 043 * Sample from a Dirichlet distribution with different concentration parameters 044 * for each category. 045 */ 046 private static final class GeneralDirichletSampler extends DirichletSampler { 047 /** Samplers for each category. */ 048 private final SharedStateContinuousSampler[] samplers; 049 050 /** 051 * @param rng Generator of uniformly distributed random numbers. 052 * @param samplers Samplers for each category. 053 */ 054 GeneralDirichletSampler(UniformRandomProvider rng, 055 SharedStateContinuousSampler[] samplers) { 056 super(rng); 057 // Array is stored directly as it is generated within the DirichletSampler class 058 this.samplers = samplers; 059 } 060 061 @Override 062 protected int getK() { 063 return samplers.length; 064 } 065 066 @Override 067 protected double nextGamma(int i) { 068 return samplers[i].sample(); 069 } 070 071 @Override 072 public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) { 073 final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length]; 074 for (int i = 0; i < newSamplers.length; i++) { 075 newSamplers[i] = samplers[i].withUniformRandomProvider(rng); 076 } 077 return new GeneralDirichletSampler(rng, newSamplers); 078 } 079 } 080 081 /** 082 * Sample from a symmetric Dirichlet distribution with the same concentration parameter 083 * for each category. 084 */ 085 private static final class SymmetricDirichletSampler extends DirichletSampler { 086 /** Number of categories. */ 087 private final int k; 088 /** Sampler for the categories. */ 089 private final SharedStateContinuousSampler sampler; 090 091 /** 092 * @param rng Generator of uniformly distributed random numbers. 093 * @param k Number of categories. 094 * @param sampler Sampler for the categories. 095 */ 096 SymmetricDirichletSampler(UniformRandomProvider rng, 097 int k, 098 SharedStateContinuousSampler sampler) { 099 super(rng); 100 this.k = k; 101 this.sampler = sampler; 102 } 103 104 @Override 105 protected int getK() { 106 return k; 107 } 108 109 @Override 110 protected double nextGamma(int i) { 111 return sampler.sample(); 112 } 113 114 @Override 115 public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) { 116 return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng)); 117 } 118 } 119 120 /** 121 * @param rng Generator of uniformly distributed random numbers. 122 */ 123 DirichletSampler(UniformRandomProvider rng) { 124 this.rng = rng; 125 } 126 127 /** {@inheritDoc} */ 128 @Override 129 public String toString() { 130 return "Dirichlet deviate [" + rng.toString() + "]"; 131 } 132 133 /** {@inheritDoc} */ 134 @Override 135 public double[] sample() { 136 // Create Gamma(alpha_i, 1) deviates for all alpha 137 final double[] y = new double[getK()]; 138 double norm = 0; 139 for (int i = 0; i < y.length; i++) { 140 final double yi = nextGamma(i); 141 norm += yi; 142 y[i] = yi; 143 } 144 // Normalize by dividing by the sum of the samples 145 norm = 1.0 / norm; 146 // Detect an invalid normalization, e.g. cases of all zero samples 147 if (!isNonZeroPositiveFinite(norm)) { 148 // Sample again using recursion. 149 // A stack overflow due to a broken RNG will eventually occur 150 // rather than the alternative which is an infinite loop. 151 return sample(); 152 } 153 // Normalise 154 for (int i = 0; i < y.length; i++) { 155 y[i] *= norm; 156 } 157 return y; 158 } 159 160 /** 161 * Gets the number of categories. 162 * 163 * @return k 164 */ 165 protected abstract int getK(); 166 167 /** 168 * Create a gamma sample for the given category. 169 * 170 * @param category Category. 171 * @return the sample 172 */ 173 protected abstract double nextGamma(int category); 174 175 /** {@inheritDoc} */ 176 // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]> 177 @Override 178 public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng); 179 180 /** 181 * Creates a new Dirichlet distribution sampler. 182 * 183 * @param rng Generator of uniformly distributed random numbers. 184 * @param alpha Concentration parameters. 185 * @return the sampler 186 * @throws IllegalArgumentException if the number of concentration parameters 187 * is less than 2; or if any concentration parameter is not strictly positive. 188 */ 189 public static DirichletSampler of(UniformRandomProvider rng, 190 double... alpha) { 191 validateNumberOfCategories(alpha.length); 192 final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length]; 193 for (int i = 0; i < samplers.length; i++) { 194 samplers[i] = createSampler(rng, alpha[i]); 195 } 196 return new GeneralDirichletSampler(rng, samplers); 197 } 198 199 /** 200 * Creates a new symmetric Dirichlet distribution sampler using the same concentration 201 * parameter for each category. 202 * 203 * @param rng Generator of uniformly distributed random numbers. 204 * @param k Number of categories. 205 * @param alpha Concentration parameter. 206 * @return the sampler 207 * @throws IllegalArgumentException if the number of categories is 208 * less than 2; or if the concentration parameter is not strictly positive. 209 */ 210 public static DirichletSampler symmetric(UniformRandomProvider rng, 211 int k, 212 double alpha) { 213 validateNumberOfCategories(k); 214 final SharedStateContinuousSampler sampler = createSampler(rng, alpha); 215 return new SymmetricDirichletSampler(rng, k, sampler); 216 } 217 218 /** 219 * Validate the number of categories. 220 * 221 * @param k Number of categories. 222 * @throws IllegalArgumentException if the number of categories is 223 * less than 2. 224 */ 225 private static void validateNumberOfCategories(int k) { 226 if (k < MIN_CATGEORIES) { 227 throw new IllegalArgumentException("Invalid number of categories: " + k); 228 } 229 } 230 231 /** 232 * Creates a gamma sampler for a category with the given concentration parameter. 233 * 234 * @param rng Generator of uniformly distributed random numbers. 235 * @param alpha Concentration parameter. 236 * @return the sampler 237 * @throws IllegalArgumentException if the concentration parameter is not strictly positive. 238 */ 239 private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng, 240 double alpha) { 241 InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration"); 242 // Create a Gamma(shape=alpha, scale=1) sampler. 243 if (alpha == 1) { 244 // Special case 245 // Gamma(shape=1, scale=1) == Exponential(mean=1) 246 return ZigguratSampler.Exponential.of(rng); 247 } 248 return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1); 249 } 250 251 /** 252 * Return true if the value is non-zero, positive and finite. 253 * 254 * @param x Value. 255 * @return true if non-zero positive finite 256 */ 257 private static boolean isNonZeroPositiveFinite(double x) { 258 return x > 0 && x < Double.POSITIVE_INFINITY; 259 } 260}