1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.SharedStateObjectSampler;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35 public abstract class DirichletSampler implements SharedStateObjectSampler<double[]> {
36
37 private static final int MIN_CATGEORIES = 2;
38
39
40 private final UniformRandomProvider rng;
41
42
43
44
45
46 private static final class GeneralDirichletSampler extends DirichletSampler {
47
48 private final SharedStateContinuousSampler[] samplers;
49
50
51
52
53
54 GeneralDirichletSampler(UniformRandomProvider rng,
55 SharedStateContinuousSampler[] samplers) {
56 super(rng);
57
58 this.samplers = samplers;
59 }
60
61 @Override
62 protected int getK() {
63 return samplers.length;
64 }
65
66 @Override
67 protected double nextGamma(int i) {
68 return samplers[i].sample();
69 }
70
71 @Override
72 public GeneralDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
73 final SharedStateContinuousSampler[] newSamplers = new SharedStateContinuousSampler[samplers.length];
74 for (int i = 0; i < newSamplers.length; i++) {
75 newSamplers[i] = samplers[i].withUniformRandomProvider(rng);
76 }
77 return new GeneralDirichletSampler(rng, newSamplers);
78 }
79 }
80
81
82
83
84
85 private static final class SymmetricDirichletSampler extends DirichletSampler {
86
87 private final int k;
88
89 private final SharedStateContinuousSampler sampler;
90
91
92
93
94
95
96 SymmetricDirichletSampler(UniformRandomProvider rng,
97 int k,
98 SharedStateContinuousSampler sampler) {
99 super(rng);
100 this.k = k;
101 this.sampler = sampler;
102 }
103
104 @Override
105 protected int getK() {
106 return k;
107 }
108
109 @Override
110 protected double nextGamma(int i) {
111 return sampler.sample();
112 }
113
114 @Override
115 public SymmetricDirichletSampler withUniformRandomProvider(UniformRandomProvider rng) {
116 return new SymmetricDirichletSampler(rng, k, sampler.withUniformRandomProvider(rng));
117 }
118 }
119
120
121
122
123 private DirichletSampler(UniformRandomProvider rng) {
124 this.rng = rng;
125 }
126
127
128 @Override
129 public String toString() {
130 return "Dirichlet deviate [" + rng.toString() + "]";
131 }
132
133 @Override
134 public double[] sample() {
135
136 final double[] y = new double[getK()];
137 double norm = 0;
138 for (int i = 0; i < y.length; i++) {
139 final double yi = nextGamma(i);
140 norm += yi;
141 y[i] = yi;
142 }
143
144 norm = 1.0 / norm;
145
146 if (!isNonZeroPositiveFinite(norm)) {
147
148
149
150 return sample();
151 }
152
153 for (int i = 0; i < y.length; i++) {
154 y[i] *= norm;
155 }
156 return y;
157 }
158
159
160
161
162
163
164 protected abstract int getK();
165
166
167
168
169
170
171
172 protected abstract double nextGamma(int category);
173
174
175
176 @Override
177 public abstract DirichletSampler withUniformRandomProvider(UniformRandomProvider rng);
178
179
180
181
182
183
184
185
186
187
188 public static DirichletSampler of(UniformRandomProvider rng,
189 double... alpha) {
190 validateNumberOfCategories(alpha.length);
191 final SharedStateContinuousSampler[] samplers = new SharedStateContinuousSampler[alpha.length];
192 for (int i = 0; i < samplers.length; i++) {
193 samplers[i] = createSampler(rng, alpha[i]);
194 }
195 return new GeneralDirichletSampler(rng, samplers);
196 }
197
198
199
200
201
202
203
204
205
206
207
208
209 public static DirichletSampler symmetric(UniformRandomProvider rng,
210 int k,
211 double alpha) {
212 validateNumberOfCategories(k);
213 final SharedStateContinuousSampler sampler = createSampler(rng, alpha);
214 return new SymmetricDirichletSampler(rng, k, sampler);
215 }
216
217
218
219
220
221
222
223
224 private static void validateNumberOfCategories(int k) {
225 if (k < MIN_CATGEORIES) {
226 throw new IllegalArgumentException("Invalid number of categories: " + k);
227 }
228 }
229
230
231
232
233
234
235
236
237
238 private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng,
239 double alpha) {
240
241 if (!isNonZeroPositiveFinite(alpha)) {
242 throw new IllegalArgumentException("Invalid concentration: " + alpha);
243 }
244
245 if (alpha == 1) {
246
247
248 return ZigguratSampler.Exponential.of(rng);
249 }
250 return AhrensDieterMarsagliaTsangGammaSampler.of(rng, alpha, 1);
251 }
252
253
254
255
256
257
258
259 private static boolean isNonZeroPositiveFinite(double x) {
260 return x > 0 && x < Double.POSITIVE_INFINITY;
261 }
262 }