1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.SharedStateObjectSampler;
21
22 /**
23 * Sampling from a <a href="https://en.wikipedia.org/wiki/Dirichlet_distribution">Dirichlet
24 * distribution</a>.
25 *
26 * <p>Sampling uses:</p>
27 *
28 * <ul>
29 * <li>{@link UniformRandomProvider#nextLong()}
30 * <li>{@link UniformRandomProvider#nextDouble()}
31 * </ul>
32 *
33 * @since 1.4
34 */
35 public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
36 /** The minimum number of categories. */
37 private static final int MIN_CATGEORIES = 2;
38
39 /** RNG (used for the toString() method). */
40 private final UniformRandomProvider rng;
41
42 /**
43 * Sample from a Dirichlet distribution with different concentration parameters
44 * for each category.
45 */
46 private static final class GeneralDirichletSampler extends DirichletSampler {
47 /** Samplers for each category. */
48 private final SharedStateContinuousSampler[] samplers;
49
50 /**
51 * @param rng Generator of uniformly distributed random numbers.
52 * @param samplers Samplers for each category.
53 */
54 GeneralDirichletSampler(UniformRandomProvider rng,
55 SharedStateContinuousSampler[] samplers) {
56 super(rng);
57 // Array is stored directly as it is generated within the DirichletSampler class
58 this.samplers = samplers;
59 }
60
61 @Override
62 protected int getK() {
63 return samplers.length;
64 }
65
66 @Override
67 protected double nextGamma(int i) {
68 return samplers[i].sample();
69 }
70
71 @Override
72 public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
73 final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
74 for (int i = 0; i < newSamplers.length; i++) {
75 newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
76 }
77 return new GeneralDirichletSampler(rng, newSamplers);
78 }
79 }
80
81 /**
82 * Sample from a symmetric Dirichlet distribution with the same concentration parameter
83 * for each category.
84 */
85 private static final class SymmetricDirichletSampler extends DirichletSampler {
86 /** Number of categories. */
87 private final int k;
88 /** Sampler for the categories. */
89 private final SharedStateContinuousSampler sampler;
90
91 /**
92 * @param rng Generator of uniformly distributed random numbers.
93 * @param k Number of categories.
94 * @param sampler Sampler for the categories.
95 */
96 SymmetricDirichletSampler(UniformRandomProvider rng,
97 int k,
98 SharedStateContinuousSampler sampler) {
99 super(rng);
100 this.k = k;
101 this.sampler = sampler;
102 }
103
104 @Override
105 protected int getK() {
106 return k;
107 }
108
109 @Override
110 protected double nextGamma(int i) {
111 return sampler.sample();
112 }
113
114 @Override
115 public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
116 return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
117 }
118 }
119
120 /**
121 * @param rng Generator of uniformly distributed random numbers.
122 */
123 DirichletSampler(UniformRandomProvider rng) {
124 this.rng = rng;
125 }
126
127 /** {@inheritDoc} */
128 @Override
129 public String toString() {
130 return "Dirichlet deviate [" + rng.toString() + "]";
131 }
132
133 /** {@inheritDoc} */
134 @Override
135 public double[] sample() {
136 // Create Gamma(alpha_i, 1) deviates for all alpha
137 final double[] y = new double[getK()];
138 double norm = 0;
139 for (int i = 0; i < y.length; i++) {
140 final double yi = nextGamma(i);
141 norm += yi;
142 y[i] = yi;
143 }
144 // Normalize by dividing by the sum of the samples
145 norm = 1.0 / norm;
146 // Detect an invalid normalization, e.g. cases of all zero samples
147 if (!isNonZeroPositiveFinite(norm)) {
148 // Sample again using recursion.
149 // A stack overflow due to a broken RNG will eventually occur
150 // rather than the alternative which is an infinite loop.
151 return sample();
152 }
153 // Normalise
154 for (int i = 0; i < y.length; i++) {
155 y[i] *= norm;
156 }
157 return y;
158 }
159
160 /**
161 * Gets the number of categories.
162 *
163 * @return k
164 */
165 protected abstract int getK();
166
167 /**
168 * Create a gamma sample for the given category.
169 *
170 * @param category Category.
171 * @return the sample
172 */
173 protected abstract double nextGamma(int category);
174
175 /** {@inheritDoc} */
176 // Redeclare the signature to return a DirichletSampler not a SharedStateObjectSampler<double[]>
177 @Override
178 public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
179
180 /**
181 * Creates a new Dirichlet distribution sampler.
182 *
183 * @param rng Generator of uniformly distributed random numbers.
184 * @param alpha Concentration parameters.
185 * @return the sampler
186 * @throws IllegalArgumentException if the number of concentration parameters
187 * is less than 2; or if any concentration parameter is not strictly positive.
188 */
189 public static DirichletSampler of(UniformRandomProvider rng,
190 double... alpha) {
191 validateNumberOfCategories(alpha.length);
192 final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
193 for (int i = 0; i < samplers.length; i++) {
194 samplers[i] = createSampler(rng, alpha[i]);
195 }
196 return new GeneralDirichletSampler(rng, samplers);
197 }
198
199 /**
200 * Creates a new symmetric Dirichlet distribution sampler using the same concentration
201 * parameter for each category.
202 *
203 * @param rng Generator of uniformly distributed random numbers.
204 * @param k Number of categories.
205 * @param alpha Concentration parameter.
206 * @return the sampler
207 * @throws IllegalArgumentException if the number of categories is
208 * less than 2; or if the concentration parameter is not strictly positive.
209 */
210 public static DirichletSampler symmetric(UniformRandomProvider rng,
211 int k,
212 double alpha) {
213 validateNumberOfCategories(k);
214 final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
215 return new SymmetricDirichletSampler(rng, k, sampler);
216 }
217
218 /**
219 * Validate the number of categories.
220 *
221 * @param k Number of categories.
222 * @throws IllegalArgumentException if the number of categories is
223 * less than 2.
224 */
225 private static void validateNumberOfCategories(int k) {
226 if (k < MIN_CATGEORIES) {
227 throw new IllegalArgumentException("Invalid number of categories: " + k);
228 }
229 }
230
231 /**
232 * Creates a gamma sampler for a category with the given concentration parameter.
233 *
234 * @param rng Generator of uniformly distributed random numbers.
235 * @param alpha Concentration parameter.
236 * @return the sampler
237 * @throws IllegalArgumentException if the concentration parameter is not strictly positive.
238 */
239 private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
240 double alpha) {
241 InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration");
242 // Create a Gamma(shape=alpha, scale=1) sampler.
243 if (alpha == 1) {
244 // Special case
245 // Gamma(shape=1, scale=1) == Exponential(mean=1)
246 return ZigguratSampler.Exponential.of(rng);
247 }
248 return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
249 }
250
251 /**
252 * Return true if the value is non-zero, positive and finite.
253 *
254 * @param x Value.
255 * @return true if non-zero positive finite
256 */
257 private static boolean isNonZeroPositiveFinite(double x) {
258 return x > 0 && x < Double.POSITIVE_INFINITY;
259 }
260 }