1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 public class ChengBetaSampler
39 extends SamplerBase
40 implements SharedStateContinuousSampler {
41
42 private static final double LN_4 = Math.log(4.0);
43
44
45 private final SharedStateContinuousSampler delegate;
46
47
48
49
50 private abstract static class BaseChengBetaSampler
51 implements SharedStateContinuousSampler {
52
53
54
55
56
57
58 protected final boolean aIsAlphaShape;
59
60
61
62
63 protected final double a;
64
65
66
67
68 protected final double b;
69
70 protected final UniformRandomProvider rng;
71
72
73
74
75 protected final double alpha;
76
77 protected final double logAlpha;
78
79
80
81
82
83
84
85 BaseChengBetaSampler(UniformRandomProvider rng, boolean aIsAlphaShape, double a, double b) {
86 this.rng = rng;
87 this.aIsAlphaShape = aIsAlphaShape;
88 this.a = a;
89 this.b = b;
90 alpha = a + b;
91 logAlpha = Math.log(alpha);
92 }
93
94
95
96
97
98 private BaseChengBetaSampler(UniformRandomProvider rng,
99 BaseChengBetaSampler source) {
100 this.rng = rng;
101 aIsAlphaShape = source.aIsAlphaShape;
102 a = source.a;
103 b = source.b;
104
105 alpha = source.alpha;
106 logAlpha = source.logAlpha;
107 }
108
109
110 @Override
111 public String toString() {
112 return "Cheng Beta deviate [" + rng.toString() + "]";
113 }
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128 protected double computeX(double w) {
129
130 final double tmp = Math.min(w, Double.MAX_VALUE);
131 return aIsAlphaShape ? tmp / (b + tmp) : b / (b + tmp);
132 }
133 }
134
135
136
137
138
139 private static class ChengBBBetaSampler extends BaseChengBetaSampler {
140
141 private static final double LN_5_P1 = 1 + Math.log(5.0);
142
143
144 private final double beta;
145
146 private final double gamma;
147
148
149
150
151
152
153
154 ChengBBBetaSampler(UniformRandomProvider rng, boolean aIsAlphaShape, double a, double b) {
155 super(rng, aIsAlphaShape, a, b);
156 beta = Math.sqrt((alpha - 2) / (2 * a * b - alpha));
157 gamma = a + 1 / beta;
158 }
159
160
161
162
163
164 private ChengBBBetaSampler(UniformRandomProvider rng,
165 ChengBBBetaSampler source) {
166 super(rng, source);
167
168 beta = source.beta;
169 gamma = source.gamma;
170 }
171
172 @Override
173 public double sample() {
174 double r;
175 double w;
176 double t;
177 do {
178
179 final double u1 = rng.nextDouble();
180 final double u2 = rng.nextDouble();
181 final double v = beta * (Math.log(u1) - Math.log1p(-u1));
182 w = a * Math.exp(v);
183 final double z = u1 * u1 * u2;
184 r = gamma * v - LN_4;
185 final double s = a + r - w;
186
187 if (s + LN_5_P1 >= 5 * z) {
188 break;
189 }
190
191
192 t = Math.log(z);
193 if (s >= t) {
194 break;
195 }
196
197 } while (r + alpha * (logAlpha - Math.log(b + w)) < t);
198
199
200 return computeX(w);
201 }
202
203 @Override
204 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
205 return new ChengBBBetaSampler(rng, this);
206 }
207 }
208
209
210
211
212
213 private static class ChengBCBetaSampler extends BaseChengBetaSampler {
214
215 private static final double ONE_HALF = 1d / 2;
216
217 private static final double ONE_QUARTER = 1d / 4;
218
219
220 private final double beta;
221
222 private final double delta;
223
224 private final double k1;
225
226 private final double k2;
227
228
229
230
231
232
233
234 ChengBCBetaSampler(UniformRandomProvider rng, boolean aIsAlphaShape, double a, double b) {
235 super(rng, aIsAlphaShape, a, b);
236
237 beta = 1 / b;
238 delta = 1 + a - b;
239
240
241
242 k1 = delta * (1.0 / 72.0 + 3.0 / 72.0 * b) / (a * beta - 7.0 / 9.0);
243 k2 = 0.25 + (0.5 + 0.25 / delta) * b;
244 }
245
246
247
248
249
250 private ChengBCBetaSampler(UniformRandomProvider rng,
251 ChengBCBetaSampler source) {
252 super(rng, source);
253 beta = source.beta;
254 delta = source.delta;
255 k1 = source.k1;
256 k2 = source.k2;
257 }
258
259 @Override
260 public double sample() {
261 double w;
262 while (true) {
263
264 final double u1 = rng.nextDouble();
265 final double u2 = rng.nextDouble();
266
267 final double y = u1 * u2;
268 final double z = u1 * y;
269 if (u1 < ONE_HALF) {
270
271 if (ONE_QUARTER * u2 + z - y >= k1) {
272 continue;
273 }
274 } else {
275
276 if (z <= ONE_QUARTER) {
277 final double v = beta * (Math.log(u1) - Math.log1p(-u1));
278 w = a * Math.exp(v);
279 break;
280 }
281
282
283 if (z >= k2) {
284 continue;
285 }
286 }
287
288
289 final double v = beta * (Math.log(u1) - Math.log1p(-u1));
290 w = a * Math.exp(v);
291 if (alpha * (logAlpha - Math.log(b + w) + v) - LN_4 >= Math.log(z)) {
292 break;
293 }
294 }
295
296
297 return computeX(w);
298 }
299
300 @Override
301 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
302 return new ChengBCBetaSampler(rng, this);
303 }
304 }
305
306
307
308
309
310
311
312
313
314 public ChengBetaSampler(UniformRandomProvider rng,
315 double alpha,
316 double beta) {
317 super(null);
318 delegate = of(rng, alpha, beta);
319 }
320
321
322 @Override
323 public double sample() {
324 return delegate.sample();
325 }
326
327
328 @Override
329 public String toString() {
330 return delegate.toString();
331 }
332
333
334
335
336
337
338 @Override
339 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
340 return delegate.withUniformRandomProvider(rng);
341 }
342
343
344
345
346
347
348
349
350
351
352
353 public static SharedStateContinuousSampler of(UniformRandomProvider rng,
354 double alpha,
355 double beta) {
356 if (alpha <= 0) {
357 throw new IllegalArgumentException("alpha is not strictly positive: " + alpha);
358 }
359 if (beta <= 0) {
360 throw new IllegalArgumentException("beta is not strictly positive: " + beta);
361 }
362
363
364 final double a = Math.min(alpha, beta);
365 final double b = Math.max(alpha, beta);
366 final boolean aIsAlphaShape = a == alpha;
367
368 return a > 1 ?
369
370 new ChengBBBetaSampler(rng, aIsAlphaShape, a, b) :
371
372
373
374 new ChengBCBetaSampler(rng, !aIsAlphaShape, b, a);
375 }
376 }