1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.statistics.distribution;
19
20 import java.util.function.DoubleSupplier;
21 import org.apache.commons.numbers.gamma.Erf;
22 import org.apache.commons.numbers.gamma.ErfDifference;
23 import org.apache.commons.numbers.gamma.Erfcx;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
26
27 /**
28 * Implementation of the truncated normal distribution.
29 *
30 * <p>The probability density function of \( X \) is:
31 *
32 * <p>\[ f(x;\mu,\sigma,a,b) = \frac{1}{\sigma}\,\frac{\phi(\frac{x - \mu}{\sigma})}{\Phi(\frac{b - \mu}{\sigma}) - \Phi(\frac{a - \mu}{\sigma}) } \]
33 *
34 * <p>for \( \mu \) mean of the parent normal distribution,
35 * \( \sigma \) standard deviation of the parent normal distribution,
36 * \( -\infty \le a \lt b \le \infty \) the truncation interval, and
37 * \( x \in [a, b] \), where \( \phi \) is the probability
38 * density function of the standard normal distribution and \( \Phi \)
39 * is its cumulative distribution function.
40 *
41 * @see <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">
42 * Truncated normal distribution (Wikipedia)</a>
43 */
44 public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {
45
46 /** The max allowed value for x where (x*x) will not overflow.
47 * This is a limit on computation of the moments of the truncated normal
48 * as some calculations assume x*x is finite. Value is sqrt(MAX_VALUE). */
49 private static final double MAX_X = 0x1.fffffffffffffp511;
50
51 /** The min allowed probability range of the parent normal distribution.
52 * Set to 0.0. This may be too low for accurate usage. It is a signal that
53 * the truncation is invalid. */
54 private static final double MIN_P = 0.0;
55
56 /** sqrt(2). */
57 private static final double ROOT2 = Constants.ROOT_TWO;
58 /** Normalisation constant 2 / sqrt(2 pi) = sqrt(2 / pi). */
59 private static final double ROOT_2_PI = Constants.ROOT_TWO_DIV_PI;
60 /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
61 private static final double ROOT_PI_2 = Constants.ROOT_PI_DIV_TWO;
62
63 /**
64 * The threshold to switch to a rejection sampler. When the truncated
65 * distribution covers more than this fraction of the CDF then rejection
66 * sampling will be more efficient than inverse CDF sampling. Performance
67 * benchmarks indicate that a normalized Gaussian sampler is up to 10 times
68 * faster than inverse transform sampling using a fast random generator. See
69 * STATISTICS-55.
70 */
71 private static final double REJECTION_THRESHOLD = 0.2;
72
73 /** Parent normal distribution. */
74 private final NormalDistribution parentNormal;
75 /** Lower bound of this distribution. */
76 private final double lower;
77 /** Upper bound of this distribution. */
78 private final double upper;
79
80 /** Stored value of {@code parentNormal.probability(lower, upper)}. This is used to
81 * normalise the probability computations. */
82 private final double cdfDelta;
83 /** log(cdfDelta). */
84 private final double logCdfDelta;
85 /** Stored value of {@code parentNormal.cumulativeProbability(lower)}. Used to map
86 * a probability into the range of the parent normal distribution. */
87 private final double cdfAlpha;
88 /** Stored value of {@code parentNormal.survivalProbability(upper)}. Used to map
89 * a probability into the range of the parent normal distribution. */
90 private final double sfBeta;
91
92 /**
93 * @param parent Parent distribution.
94 * @param z Probability of the parent distribution for {@code [lower, upper]}.
95 * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
96 * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
97 */
98 private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
99 this.parentNormal = parent;
100 this.lower = lower;
101 this.upper = upper;
102
103 cdfDelta = z;
104 logCdfDelta = Math.log(cdfDelta);
105 // Used to map the inverse probability.
106 cdfAlpha = parentNormal.cumulativeProbability(lower);
107 sfBeta = parentNormal.survivalProbability(upper);
108 }
109
110 /**
111 * Creates a truncated normal distribution.
112 *
113 * <p>Note that the {@code mean} and {@code sd} is of the parent normal distribution,
114 * and not the true mean and standard deviation of the truncated normal distribution.
115 * The {@code lower} and {@code upper} bounds define the truncation of the parent
116 * normal distribution.
117 *
118 * @param mean Mean for the parent distribution.
119 * @param sd Standard deviation for the parent distribution.
120 * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
121 * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
122 * @return the distribution
123 * @throws IllegalArgumentException if {@code sd <= 0}; if {@code lower >= upper}; or if
124 * the truncation covers no probability range in the parent distribution.
125 */
126 public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
127 if (sd <= 0) {
128 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
129 }
130 if (lower >= upper) {
131 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH, lower, upper);
132 }
133
134 // Use an instance for the parent normal distribution to maximise accuracy
135 // in range computations using the error function
136 final NormalDistribution parent = NormalDistribution.of(mean, sd);
137
138 // If there is no computable range then raise an exception.
139 final double z = parent.probability(lower, upper);
140 if (z <= MIN_P) {
141 // Map the bounds to a standard normal distribution for the message
142 final double a = (lower - mean) / sd;
143 final double b = (upper - mean) / sd;
144 throw new DistributionException(
145 "Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z);
146 }
147
148 // Here we have a meaningful truncation. Note that excess truncation may not be optimal.
149 // For example truncation close to zero where the PDF is constant can be approximated
150 // using a uniform distribution.
151
152 return new TruncatedNormalDistribution(parent, z, lower, upper);
153 }
154
155 /** {@inheritDoc} */
156 @Override
157 public double density(double x) {
158 if (x < lower || x > upper) {
159 return 0;
160 }
161 return parentNormal.density(x) / cdfDelta;
162 }
163
164 /** {@inheritDoc} */
165 @Override
166 public double probability(double x0, double x1) {
167 if (x0 > x1) {
168 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
169 x0, x1);
170 }
171 return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
172 }
173
174 /** {@inheritDoc} */
175 @Override
176 public double logDensity(double x) {
177 if (x < lower || x > upper) {
178 return Double.NEGATIVE_INFINITY;
179 }
180 return parentNormal.logDensity(x) - logCdfDelta;
181 }
182
183 /** {@inheritDoc} */
184 @Override
185 public double cumulativeProbability(double x) {
186 if (x <= lower) {
187 return 0;
188 } else if (x >= upper) {
189 return 1;
190 }
191 return parentNormal.probability(lower, x) / cdfDelta;
192 }
193
194 /** {@inheritDoc} */
195 @Override
196 public double survivalProbability(double x) {
197 if (x <= lower) {
198 return 1;
199 } else if (x >= upper) {
200 return 0;
201 }
202 return parentNormal.probability(x, upper) / cdfDelta;
203 }
204
205 /** {@inheritDoc} */
206 @Override
207 public double inverseCumulativeProbability(double p) {
208 ArgumentUtils.checkProbability(p);
209 // Exact bound
210 if (p == 0) {
211 return lower;
212 } else if (p == 1) {
213 return upper;
214 }
215 // Linearly map p to the range [lower, upper]
216 final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta);
217 return clipToRange(x);
218 }
219
220 /** {@inheritDoc} */
221 @Override
222 public double inverseSurvivalProbability(double p) {
223 ArgumentUtils.checkProbability(p);
224 // Exact bound
225 if (p == 1) {
226 return lower;
227 } else if (p == 0) {
228 return upper;
229 }
230 // Linearly map p to the range [lower, upper]
231 final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta);
232 return clipToRange(x);
233 }
234
235 /** {@inheritDoc} */
236 @Override
237 public Sampler createSampler(UniformRandomProvider rng) {
238 // If the truncation covers a reasonable amount of the normal distribution
239 // then a rejection sampler can be used.
240 double threshold = REJECTION_THRESHOLD;
241 // If the truncation is entirely in the upper or lower half then adjust the
242 // threshold as twice the samples can be used
243 if (lower >= 0 || upper <= 0) {
244 threshold *= 0.5;
245 }
246
247 if (cdfDelta > threshold) {
248 // Create the rejection sampler
249 final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
250 final DoubleSupplier gen;
251 // Use mirroring if possible
252 if (lower >= 0) {
253 // Return the upper-half of the Gaussian
254 gen = () -> Math.abs(sampler.sample());
255 } else if (upper <= 0) {
256 // Return the lower-half of the Gaussian
257 gen = () -> -Math.abs(sampler.sample());
258 } else {
259 // Return the full range of the Gaussian
260 gen = sampler::sample;
261 }
262 // Map the bounds to a standard normal distribution
263 final double u = parentNormal.getMean();
264 final double s = parentNormal.getStandardDeviation();
265 final double a = (lower - u) / s;
266 final double b = (upper - u) / s;
267 // Sample in [a, b] using rejection
268 return () -> {
269 double x = gen.getAsDouble();
270 while (x < a || x > b) {
271 x = gen.getAsDouble();
272 }
273 // Avoid floating-point error when mapping back
274 return clipToRange(u + x * s);
275 };
276 }
277
278 // Default to an inverse CDF sampler
279 return super.createSampler(rng);
280 }
281
282 /**
283 * {@inheritDoc}
284 *
285 * <p>Represents the true mean of the truncated normal distribution rather
286 * than the parent normal distribution mean.
287 *
288 * <p>For \( \mu \) mean of the parent normal distribution,
289 * \( \sigma \) standard deviation of the parent normal distribution, and
290 * \( a \lt b \) the truncation interval of the parent normal distribution, the mean is:
291 *
292 * <p>\[ \mu + \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)}\sigma \]
293 *
294 * <p>where \( \phi \) is the probability density function of the standard normal distribution
295 * and \( \Phi \) is its cumulative distribution function.
296 */
297 @Override
298 public double getMean() {
299 final double u = parentNormal.getMean();
300 final double s = parentNormal.getStandardDeviation();
301 final double a = (lower - u) / s;
302 final double b = (upper - u) / s;
303 return u + moment1(a, b) * s;
304 }
305
306 /**
307 * {@inheritDoc}
308 *
309 * <p>Represents the true variance of the truncated normal distribution rather
310 * than the parent normal distribution variance.
311 *
312 * <p>For \( \mu \) mean of the parent normal distribution,
313 * \( \sigma \) standard deviation of the parent normal distribution, and
314 * \( a \lt b \) the truncation interval of the parent normal distribution, the variance is:
315 *
316 * <p>\[ \sigma^2 \left[1 + \frac{a\phi(a)-b\phi(b)}{\Phi(b) - \Phi(a)} -
317 * \left( \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)} \right)^2 \right] \]
318 *
319 * <p>where \( \phi \) is the probability density function of the standard normal distribution
320 * and \( \Phi \) is its cumulative distribution function.
321 */
322 @Override
323 public double getVariance() {
324 final double u = parentNormal.getMean();
325 final double s = parentNormal.getStandardDeviation();
326 final double a = (lower - u) / s;
327 final double b = (upper - u) / s;
328 return variance(a, b) * s * s;
329 }
330
331 /**
332 * {@inheritDoc}
333 *
334 * <p>The lower bound of the support is equal to the lower bound parameter
335 * of the distribution.
336 */
337 @Override
338 public double getSupportLowerBound() {
339 return lower;
340 }
341
342 /**
343 * {@inheritDoc}
344 *
345 * <p>The upper bound of the support is equal to the upper bound parameter
346 * of the distribution.
347 */
348 @Override
349 public double getSupportUpperBound() {
350 return upper;
351 }
352
353 /**
354 * Clip the value to the range [lower, upper].
355 * This is used to handle floating-point error at the support bound.
356 *
357 * @param x Value x
358 * @return x clipped to the range
359 */
360 private double clipToRange(double x) {
361 return clip(x, lower, upper);
362 }
363
364 /**
365 * Clip the value to the range [lower, upper].
366 *
367 * @param x Value x
368 * @param lower Lower bound (inclusive)
369 * @param upper Upper bound (inclusive)
370 * @return x clipped to the range
371 */
372 private static double clip(double x, double lower, double upper) {
373 if (x <= lower) {
374 return lower;
375 }
376 return x < upper ? x : upper;
377 }
378
379 // Calculation of variance and mean can suffer from cancellation.
380 //
381 // Use formulas from Jorge Fernandez-de-Cossio-Diaz adapted under the
382 // terms of the MIT "Expat" License (see NOTICE and LICENSE).
383 //
384 // These formulas use the complementary error function
385 // erfcx(z) = erfc(z) * exp(z^2)
386 // This avoids computation of exp terms for the Gaussian PDF and then
387 // dividing by the error functions erf or erfc:
388 // exp(-0.5*x*x) / erfc(x / sqrt(2)) == 1 / erfcx(x / sqrt(2))
389 // At large z the erfcx function is computable but exp(-0.5*z*z) and
390 // erfc(z) are zero. Use of these formulas allows computation of the
391 // mean and variance for the usable range of the truncated distribution
392 // (cdf(a, b) != 0). The variance is not accurate when it approaches
393 // machine epsilon (2^-52) at extremely narrow truncations and the
394 // computation -> 0.
395 //
396 // See: https://github.com/cossio/TruncatedNormal.jl
397
398 /**
399 * Compute the first moment (mean) of the truncated standard normal distribution.
400 *
401 * <p>Assumes {@code a <= b}.
402 *
403 * @param a Lower bound
404 * @param b Upper bound
405 * @return the first moment
406 */
407 static double moment1(double a, double b) {
408 // Assume a <= b
409 if (a == b) {
410 return a;
411 }
412 if (Math.abs(a) > Math.abs(b)) {
413 // Subtract from zero to avoid generating -0.0
414 return 0 - moment1(-b, -a);
415 }
416
417 // Here:
418 // |a| <= |b|
419 // a < b
420 // 0 < b
421
422 if (a <= -MAX_X) {
423 // No truncation
424 return 0;
425 }
426 if (b >= MAX_X) {
427 // One-sided truncation
428 return ROOT_2_PI / Erfcx.value(a / ROOT2);
429 }
430
431 // pdf = exp(-0.5*x*x) / sqrt(2*pi)
432 // cdf = erfc(-x/sqrt(2)) / 2
433 // Compute:
434 // -(pdf(b) - pdf(a)) / cdf(b, a)
435 // Note:
436 // exp(-0.5*b*b) - exp(-0.5*a*a)
437 // Use cancellation of powers:
438 // exp(-0.5*(b*b-a*a)) * exp(-0.5*a*a) - exp(-0.5*a*a)
439 // expm1(-0.5*(b*b-a*a)) * exp(-0.5*a*a)
440
441 // dx = -0.5*(b*b-a*a)
442 final double dx = 0.5 * (b + a) * (b - a);
443 final double m;
444 if (a <= 0) {
445 // Opposite signs
446 m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2);
447 } else {
448 final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
449 if (z == 0) {
450 // Occurs when a and b have large magnitudes and are very close
451 return (a + b) * 0.5;
452 }
453 m = ROOT_2_PI * Math.expm1(-dx) / z;
454 }
455
456 // Clip to the range
457 return clip(m, a, b);
458 }
459
460 /**
461 * Compute the second moment of the truncated standard normal distribution.
462 *
463 * <p>Assumes {@code a <= b}.
464 *
465 * @param a Lower bound
466 * @param b Upper bound
467 * @return the first moment
468 */
469 private static double moment2(double a, double b) {
470 // Assume a < b.
471 // a == b is handled in the variance method
472 if (Math.abs(a) > Math.abs(b)) {
473 return moment2(-b, -a);
474 }
475
476 // Here:
477 // |a| <= |b|
478 // a < b
479 // 0 < b
480
481 if (a <= -MAX_X) {
482 // No truncation
483 return 1;
484 }
485 if (b >= MAX_X) {
486 // One-sided truncation.
487 // For a -> inf : moment2 -> a*a
488 // This occurs when erfcx(z) is approximated by (1/sqrt(pi)) / z and terms
489 // cancel. z > 6.71e7, a > 9.49e7
490 return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
491 }
492
493 // pdf = exp(-0.5*x*x) / sqrt(2*pi)
494 // cdf = erfc(-x/sqrt(2)) / 2
495 // Compute:
496 // 1 - (b*pdf(b) - a*pdf(a)) / cdf(b, a)
497 // = (cdf(b, a) - b*pdf(b) -a*pdf(a)) / cdf(b, a)
498
499 // Note:
500 // For z -> 0:
501 // sqrt(pi / 2) * erf(z / sqrt(2)) -> z
502 // z * Math.exp(-0.5 * z * z) -> z
503 // Both computations below have cancellation as b -> 0 and the
504 // second moment is not computable as the fraction P/Q
505 // since P < ulp(Q). This always occurs when b < MIN_X
506 // if MIN_X is set at the point where
507 // exp(-0.5 * z * z) / sqrt(2 pi) == 1 / sqrt(2 pi).
508 // This is JDK dependent due to variations in Math.exp.
509 // For b < MIN_X the second moment can be approximated using
510 // a uniform distribution: (b^3 - a^3) / (3b - 3a).
511 // In practice it also occurs when b > MIN_X since any a < MIN_X
512 // is effectively zero for part of the computation. A
513 // threshold to transition to a uniform distribution
514 // approximation is a compromise. Also note it will not
515 // correct computation when (b-a) is small and is far from 0.
516 // Thus the second moment is left to be inaccurate for
517 // small ranges (b-a) and the variance -> 0 when the true
518 // variance is close to or below machine epsilon.
519
520 double m;
521
522 if (a <= 0) {
523 // Opposite signs
524 final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
525 final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
526 final double fa = ea - a * Math.exp(-0.5 * a * a);
527 final double fb = eb - b * Math.exp(-0.5 * b * b);
528 // Assume fb >= fa && eb >= ea
529 // If fb <= fa this is a tiny range around 0
530 m = (fb - fa) / (eb - ea);
531 // Clip to the range
532 m = clip(m, 0, 1);
533 } else {
534 final double dx = 0.5 * (b + a) * (b - a);
535 final double ex = Math.exp(-dx);
536 final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
537 final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
538 final double fa = ea + a;
539 final double fb = eb + b;
540 m = (fa - fb * ex) / (ea - eb * ex);
541 // Clip to the range
542 m = clip(m, a * a, b * b);
543 }
544 return m;
545 }
546
547 /**
548 * Compute the variance of the truncated standard normal distribution.
549 *
550 * <p>Assumes {@code a <= b}.
551 *
552 * @param a Lower bound
553 * @param b Upper bound
554 * @return the first moment
555 */
556 static double variance(double a, double b) {
557 if (a == b) {
558 return 0;
559 }
560
561 final double m1 = moment1(a, b);
562 double m2 = moment2(a, b);
563 // variance = m2 - m1*m1
564 // rearrange x^2 - y^2 as (x-y)(x+y)
565 m2 = Math.sqrt(m2);
566 final double variance = (m2 - m1) * (m2 + m1);
567
568 // Detect floating-point error.
569 if (variance >= 1) {
570 // Note:
571 // Extreme truncations in the tails can compute a variance above 1,
572 // for example if m2 is infinite: m2 - m1*m1 > 1
573 // Detect no truncation as the terms a and b lie far either side of zero;
574 // otherwise return 0 to indicate very small unknown variance.
575 return a < -1 && b > 1 ? 1 : 0;
576 } else if (variance <= 0) {
577 // Floating-point error can create negative variance so return 0.
578 return 0;
579 }
580
581 return variance;
582 }
583 }