1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.statistics.distribution;
18
19 import org.apache.commons.numbers.gamma.Erf;
20 import org.apache.commons.numbers.gamma.ErfDifference;
21 import org.apache.commons.numbers.gamma.Erfc;
22 import org.apache.commons.numbers.gamma.InverseErf;
23 import org.apache.commons.numbers.gamma.InverseErfc;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.sampling.distribution.GaussianSampler;
26 import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
27 import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51 public abstract class FoldedNormalDistribution extends AbstractContinuousDistribution {
52
53 final double sigma;
54
55
56
57
58
59
60
61
62
63
64
65 final double sigmaSqrt2;
66
67
68
69 final double sigmaSqrt2pi;
70
71
72
73
74 private static class RegularFoldedNormalDistribution extends FoldedNormalDistribution {
75
76 private final double mu;
77
78 private final double mean;
79
80 private final double variance;
81
82
83
84
85
86 RegularFoldedNormalDistribution(double mu, double sigma) {
87 super(sigma);
88 this.mu = mu;
89
90 final double a = mu / sigmaSqrt2;
91 mean = sigma * Constants.ROOT_TWO_DIV_PI * Math.exp(-a * a) + mu * Erf.value(a);
92 this.variance = mu * mu + sigma * sigma - mean * mean;
93 }
94
95 @Override
96 public double getMu() {
97 return mu;
98 }
99
100 @Override
101 public double density(double x) {
102 if (x < 0) {
103 return 0;
104 }
105 final double vm = (x - mu) / sigma;
106 final double vp = (x + mu) / sigma;
107 return (ExtendedPrecision.expmhxx(vm) + ExtendedPrecision.expmhxx(vp)) / sigmaSqrt2pi;
108 }
109
110 @Override
111 public double probability(double x0,
112 double x1) {
113 if (x0 > x1) {
114 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
115 x0, x1);
116 }
117 if (x0 <= 0) {
118 return cumulativeProbability(x1);
119 }
120
121 final double v0m = (x0 - mu) / sigmaSqrt2;
122 final double v1m = (x1 - mu) / sigmaSqrt2;
123 final double v0p = (x0 + mu) / sigmaSqrt2;
124 final double v1p = (x1 + mu) / sigmaSqrt2;
125 return 0.5 * (ErfDifference.value(v0m, v1m) + ErfDifference.value(v0p, v1p));
126 }
127
128 @Override
129 public double cumulativeProbability(double x) {
130 if (x <= 0) {
131 return 0;
132 }
133 return 0.5 * (Erf.value((x - mu) / sigmaSqrt2) + Erf.value((x + mu) / sigmaSqrt2));
134 }
135
136 @Override
137 public double survivalProbability(double x) {
138 if (x <= 0) {
139 return 1;
140 }
141 return 0.5 * (Erfc.value((x - mu) / sigmaSqrt2) + Erfc.value((x + mu) / sigmaSqrt2));
142 }
143
144 @Override
145 public double getMean() {
146 return mean;
147 }
148
149 @Override
150 public double getVariance() {
151 return variance;
152 }
153
154 @Override
155 public Sampler createSampler(UniformRandomProvider rng) {
156
157 final SharedStateContinuousSampler s =
158 GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng), mu, sigma);
159 return () -> Math.abs(s.sample());
160 }
161 }
162
163
164
165
166
167
168
169 private static class HalfNormalDistribution extends FoldedNormalDistribution {
170
171 private static final double VAR = 0.36338022763241865692446494650994;
172
173 private final double logSigmaPlusHalfLog2Pi;
174
175
176
177
178 HalfNormalDistribution(double sigma) {
179 super(sigma);
180 logSigmaPlusHalfLog2Pi = Math.log(sigma) + Constants.HALF_LOG_TWO_PI;
181 }
182
183 @Override
184 public double getMu() {
185 return 0;
186 }
187
188 @Override
189 public double density(double x) {
190 if (x < 0) {
191 return 0;
192 }
193 return 2 * ExtendedPrecision.expmhxx(x / sigma) / sigmaSqrt2pi;
194 }
195
196 @Override
197 public double probability(double x0,
198 double x1) {
199 if (x0 > x1) {
200 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
201 x0, x1);
202 }
203 if (x0 <= 0) {
204 return cumulativeProbability(x1);
205 }
206
207 return ErfDifference.value(x0 / sigmaSqrt2, x1 / sigmaSqrt2);
208 }
209
210 @Override
211 public double logDensity(double x) {
212 if (x < 0) {
213 return Double.NEGATIVE_INFINITY;
214 }
215 final double z = x / sigma;
216 return Constants.LN_TWO - 0.5 * z * z - logSigmaPlusHalfLog2Pi;
217 }
218
219 @Override
220 public double cumulativeProbability(double x) {
221 if (x <= 0) {
222 return 0;
223 }
224 return Erf.value(x / sigmaSqrt2);
225 }
226
227 @Override
228 public double survivalProbability(double x) {
229 if (x <= 0) {
230 return 1;
231 }
232 return Erfc.value(x / sigmaSqrt2);
233 }
234
235 @Override
236 public double inverseCumulativeProbability(double p) {
237 ArgumentUtils.checkProbability(p);
238
239 return 0.0 + sigmaSqrt2 * InverseErf.value(p);
240 }
241
242
243 @Override
244 public double inverseSurvivalProbability(double p) {
245 ArgumentUtils.checkProbability(p);
246 return sigmaSqrt2 * InverseErfc.value(p);
247 }
248
249 @Override
250 public double getMean() {
251 return sigma * Constants.ROOT_TWO_DIV_PI;
252 }
253
254 @Override
255 public double getVariance() {
256
257
258 return sigma * sigma * VAR;
259 }
260
261 @Override
262 public Sampler createSampler(UniformRandomProvider rng) {
263
264 final SharedStateContinuousSampler s = ZigguratSampler.NormalizedGaussian.of(rng);
265 return () -> Math.abs(s.sample() * sigma);
266 }
267 }
268
269
270
271
272 FoldedNormalDistribution(double sigma) {
273 this.sigma = sigma;
274
275
276 sigmaSqrt2 = ExtendedPrecision.sqrt2xx(sigma);
277
278 sigmaSqrt2pi = ExtendedPrecision.xsqrt2pi(sigma);
279 }
280
281
282
283
284
285
286
287
288
289
290 public static FoldedNormalDistribution of(double mu,
291 double sigma) {
292 if (sigma > 0) {
293 if (mu == 0) {
294 return new HalfNormalDistribution(sigma);
295 }
296 return new RegularFoldedNormalDistribution(mu, sigma);
297 }
298
299 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sigma);
300 }
301
302
303
304
305
306
307 public abstract double getMu();
308
309
310
311
312
313
314 public double getSigma() {
315 return sigma;
316 }
317
318
319
320
321
322
323
324
325
326
327
328
329 @Override
330 public abstract double getMean();
331
332
333
334
335
336
337
338
339
340 @Override
341 public abstract double getVariance();
342
343
344
345
346
347
348
349
350 @Override
351 public double getSupportLowerBound() {
352 return 0.0;
353 }
354
355
356
357
358
359
360
361
362 @Override
363 public double getSupportUpperBound() {
364 return Double.POSITIVE_INFINITY;
365 }
366 }