TruncatedNormalDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.statistics.distribution;

  18. import java.util.function.DoubleSupplier;
  19. import org.apache.commons.numbers.gamma.Erf;
  20. import org.apache.commons.numbers.gamma.ErfDifference;
  21. import org.apache.commons.numbers.gamma.Erfcx;
  22. import org.apache.commons.rng.UniformRandomProvider;
  23. import org.apache.commons.rng.sampling.distribution.ZigguratSampler;

  24. /**
  25.  * Implementation of the truncated normal distribution.
  26.  *
  27.  * <p>The probability density function of \( X \) is:
  28.  *
  29.  * <p>\[ f(x;\mu,\sigma,a,b) = \frac{1}{\sigma}\,\frac{\phi(\frac{x - \mu}{\sigma})}{\Phi(\frac{b - \mu}{\sigma}) - \Phi(\frac{a - \mu}{\sigma}) } \]
  30.  *
  31.  * <p>for \( \mu \) mean of the parent normal distribution,
  32.  * \( \sigma \) standard deviation of the parent normal distribution,
  33.  * \( -\infty \le a \lt b \le \infty \) the truncation interval, and
  34.  * \( x \in [a, b] \), where \( \phi \) is the probability
  35.  * density function of the standard normal distribution and \( \Phi \)
  36.  * is its cumulative distribution function.
  37.  *
  38.  * @see <a href="https://en.wikipedia.org/wiki/Truncated_normal_distribution">
  39.  * Truncated normal distribution (Wikipedia)</a>
  40.  */
  41. public final class TruncatedNormalDistribution extends AbstractContinuousDistribution {

  42.     /** The max allowed value for x where (x*x) will not overflow.
  43.      * This is a limit on computation of the moments of the truncated normal
  44.      * as some calculations assume x*x is finite. Value is sqrt(MAX_VALUE). */
  45.     private static final double MAX_X = 0x1.fffffffffffffp511;

  46.     /** The min allowed probability range of the parent normal distribution.
  47.      * Set to 0.0. This may be too low for accurate usage. It is a signal that
  48.      * the truncation is invalid. */
  49.     private static final double MIN_P = 0.0;

  50.     /** sqrt(2). */
  51.     private static final double ROOT2 = Constants.ROOT_TWO;
  52.     /** Normalisation constant 2 / sqrt(2 pi) = sqrt(2 / pi). */
  53.     private static final double ROOT_2_PI = Constants.ROOT_TWO_DIV_PI;
  54.     /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
  55.     private static final double ROOT_PI_2 = Constants.ROOT_PI_DIV_TWO;

  56.     /**
  57.      * The threshold to switch to a rejection sampler. When the truncated
  58.      * distribution covers more than this fraction of the CDF then rejection
  59.      * sampling will be more efficient than inverse CDF sampling. Performance
  60.      * benchmarks indicate that a normalized Gaussian sampler is up to 10 times
  61.      * faster than inverse transform sampling using a fast random generator. See
  62.      * STATISTICS-55.
  63.      */
  64.     private static final double REJECTION_THRESHOLD = 0.2;

  65.     /** Parent normal distribution. */
  66.     private final NormalDistribution parentNormal;
  67.     /** Lower bound of this distribution. */
  68.     private final double lower;
  69.     /** Upper bound of this distribution. */
  70.     private final double upper;

  71.     /** Stored value of {@code parentNormal.probability(lower, upper)}. This is used to
  72.      * normalise the probability computations. */
  73.     private final double cdfDelta;
  74.     /** log(cdfDelta). */
  75.     private final double logCdfDelta;
  76.     /** Stored value of {@code parentNormal.cumulativeProbability(lower)}. Used to map
  77.      * a probability into the range of the parent normal distribution. */
  78.     private final double cdfAlpha;
  79.     /** Stored value of {@code parentNormal.survivalProbability(upper)}. Used to map
  80.      * a probability into the range of the parent normal distribution. */
  81.     private final double sfBeta;

  82.     /**
  83.      * @param parent Parent distribution.
  84.      * @param z Probability of the parent distribution for {@code [lower, upper]}.
  85.      * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
  86.      * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
  87.      */
  88.     private TruncatedNormalDistribution(NormalDistribution parent, double z, double lower, double upper) {
  89.         this.parentNormal = parent;
  90.         this.lower = lower;
  91.         this.upper = upper;

  92.         cdfDelta = z;
  93.         logCdfDelta = Math.log(cdfDelta);
  94.         // Used to map the inverse probability.
  95.         cdfAlpha = parentNormal.cumulativeProbability(lower);
  96.         sfBeta = parentNormal.survivalProbability(upper);
  97.     }

  98.     /**
  99.      * Creates a truncated normal distribution.
  100.      *
  101.      * <p>Note that the {@code mean} and {@code sd} is of the parent normal distribution,
  102.      * and not the true mean and standard deviation of the truncated normal distribution.
  103.      * The {@code lower} and {@code upper} bounds define the truncation of the parent
  104.      * normal distribution.
  105.      *
  106.      * @param mean Mean for the parent distribution.
  107.      * @param sd Standard deviation for the parent distribution.
  108.      * @param lower Lower bound (inclusive) of the distribution, can be {@link Double#NEGATIVE_INFINITY}.
  109.      * @param upper Upper bound (inclusive) of the distribution, can be {@link Double#POSITIVE_INFINITY}.
  110.      * @return the distribution
  111.      * @throws IllegalArgumentException if {@code sd <= 0}; if {@code lower >= upper}; or if
  112.      * the truncation covers no probability range in the parent distribution.
  113.      */
  114.     public static TruncatedNormalDistribution of(double mean, double sd, double lower, double upper) {
  115.         if (sd <= 0) {
  116.             throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
  117.         }
  118.         if (lower >= upper) {
  119.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GTE_HIGH, lower, upper);
  120.         }

  121.         // Use an instance for the parent normal distribution to maximise accuracy
  122.         // in range computations using the error function
  123.         final NormalDistribution parent = NormalDistribution.of(mean, sd);

  124.         // If there is no computable range then raise an exception.
  125.         final double z = parent.probability(lower, upper);
  126.         if (z <= MIN_P) {
  127.             // Map the bounds to a standard normal distribution for the message
  128.             final double a = (lower - mean) / sd;
  129.             final double b = (upper - mean) / sd;
  130.             throw new DistributionException(
  131.                "Excess truncation of standard normal : CDF(%s, %s) = %s", a, b, z);
  132.         }

  133.         // Here we have a meaningful truncation. Note that excess truncation may not be optimal.
  134.         // For example truncation close to zero where the PDF is constant can be approximated
  135.         // using a uniform distribution.

  136.         return new TruncatedNormalDistribution(parent, z, lower, upper);
  137.     }

  138.     /** {@inheritDoc} */
  139.     @Override
  140.     public double density(double x) {
  141.         if (x < lower || x > upper) {
  142.             return 0;
  143.         }
  144.         return parentNormal.density(x) / cdfDelta;
  145.     }

  146.     /** {@inheritDoc} */
  147.     @Override
  148.     public double probability(double x0, double x1) {
  149.         if (x0 > x1) {
  150.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
  151.                                             x0, x1);
  152.         }
  153.         return parentNormal.probability(clipToRange(x0), clipToRange(x1)) / cdfDelta;
  154.     }

  155.     /** {@inheritDoc} */
  156.     @Override
  157.     public double logDensity(double x) {
  158.         if (x < lower || x > upper) {
  159.             return Double.NEGATIVE_INFINITY;
  160.         }
  161.         return parentNormal.logDensity(x) - logCdfDelta;
  162.     }

  163.     /** {@inheritDoc} */
  164.     @Override
  165.     public double cumulativeProbability(double x) {
  166.         if (x <= lower) {
  167.             return 0;
  168.         } else if (x >= upper) {
  169.             return 1;
  170.         }
  171.         return parentNormal.probability(lower, x) / cdfDelta;
  172.     }

  173.     /** {@inheritDoc} */
  174.     @Override
  175.     public double survivalProbability(double x) {
  176.         if (x <= lower) {
  177.             return 1;
  178.         } else if (x >= upper) {
  179.             return 0;
  180.         }
  181.         return parentNormal.probability(x, upper) / cdfDelta;
  182.     }

  183.     /** {@inheritDoc} */
  184.     @Override
  185.     public double inverseCumulativeProbability(double p) {
  186.         ArgumentUtils.checkProbability(p);
  187.         // Exact bound
  188.         if (p == 0) {
  189.             return lower;
  190.         } else if (p == 1) {
  191.             return upper;
  192.         }
  193.         // Linearly map p to the range [lower, upper]
  194.         final double x = parentNormal.inverseCumulativeProbability(cdfAlpha + p * cdfDelta);
  195.         return clipToRange(x);
  196.     }

  197.     /** {@inheritDoc} */
  198.     @Override
  199.     public double inverseSurvivalProbability(double p) {
  200.         ArgumentUtils.checkProbability(p);
  201.         // Exact bound
  202.         if (p == 1) {
  203.             return lower;
  204.         } else if (p == 0) {
  205.             return upper;
  206.         }
  207.         // Linearly map p to the range [lower, upper]
  208.         final double x = parentNormal.inverseSurvivalProbability(sfBeta + p * cdfDelta);
  209.         return clipToRange(x);
  210.     }

  211.     /** {@inheritDoc} */
  212.     @Override
  213.     public Sampler createSampler(UniformRandomProvider rng) {
  214.         // If the truncation covers a reasonable amount of the normal distribution
  215.         // then a rejection sampler can be used.
  216.         double threshold = REJECTION_THRESHOLD;
  217.         // If the truncation is entirely in the upper or lower half then adjust the
  218.         // threshold as twice the samples can be used
  219.         if (lower >= 0 || upper <= 0) {
  220.             threshold *= 0.5;
  221.         }

  222.         if (cdfDelta > threshold) {
  223.             // Create the rejection sampler
  224.             final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
  225.             final DoubleSupplier gen;
  226.             // Use mirroring if possible
  227.             if (lower >= 0) {
  228.                 // Return the upper-half of the Gaussian
  229.                 gen = () -> Math.abs(sampler.sample());
  230.             } else if (upper <= 0) {
  231.                 // Return the lower-half of the Gaussian
  232.                 gen = () -> -Math.abs(sampler.sample());
  233.             } else {
  234.                 // Return the full range of the Gaussian
  235.                 gen = sampler::sample;
  236.             }
  237.             // Map the bounds to a standard normal distribution
  238.             final double u = parentNormal.getMean();
  239.             final double s = parentNormal.getStandardDeviation();
  240.             final double a = (lower - u) / s;
  241.             final double b = (upper - u) / s;
  242.             // Sample in [a, b] using rejection
  243.             return () -> {
  244.                 double x = gen.getAsDouble();
  245.                 while (x < a || x > b) {
  246.                     x = gen.getAsDouble();
  247.                 }
  248.                 // Avoid floating-point error when mapping back
  249.                 return clipToRange(u + x * s);
  250.             };
  251.         }

  252.         // Default to an inverse CDF sampler
  253.         return super.createSampler(rng);
  254.     }

  255.     /**
  256.      * {@inheritDoc}
  257.      *
  258.      * <p>Represents the true mean of the truncated normal distribution rather
  259.      * than the parent normal distribution mean.
  260.      *
  261.      * <p>For \( \mu \) mean of the parent normal distribution,
  262.      * \( \sigma \) standard deviation of the parent normal distribution, and
  263.      * \( a \lt b \) the truncation interval of the parent normal distribution, the mean is:
  264.      *
  265.      * <p>\[ \mu + \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)}\sigma \]
  266.      *
  267.      * <p>where \( \phi \) is the probability density function of the standard normal distribution
  268.      * and \( \Phi \) is its cumulative distribution function.
  269.      */
  270.     @Override
  271.     public double getMean() {
  272.         final double u = parentNormal.getMean();
  273.         final double s = parentNormal.getStandardDeviation();
  274.         final double a = (lower - u) / s;
  275.         final double b = (upper - u) / s;
  276.         return u + moment1(a, b) * s;
  277.     }

  278.     /**
  279.      * {@inheritDoc}
  280.      *
  281.      * <p>Represents the true variance of the truncated normal distribution rather
  282.      * than the parent normal distribution variance.
  283.      *
  284.      * <p>For \( \mu \) mean of the parent normal distribution,
  285.      * \( \sigma \) standard deviation of the parent normal distribution, and
  286.      * \( a \lt b \) the truncation interval of the parent normal distribution, the variance is:
  287.      *
  288.      * <p>\[ \sigma^2 \left[1 + \frac{a\phi(a)-b\phi(b)}{\Phi(b) - \Phi(a)} -
  289.      *       \left( \frac{\phi(a)-\phi(b)}{\Phi(b) - \Phi(a)} \right)^2 \right] \]
  290.      *
  291.      * <p>where \( \phi \) is the probability density function of the standard normal distribution
  292.      * and \( \Phi \) is its cumulative distribution function.
  293.      */
  294.     @Override
  295.     public double getVariance() {
  296.         final double u = parentNormal.getMean();
  297.         final double s = parentNormal.getStandardDeviation();
  298.         final double a = (lower - u) / s;
  299.         final double b = (upper - u) / s;
  300.         return variance(a, b) * s * s;
  301.     }

  302.     /**
  303.      * {@inheritDoc}
  304.      *
  305.      * <p>The lower bound of the support is equal to the lower bound parameter
  306.      * of the distribution.
  307.      */
  308.     @Override
  309.     public double getSupportLowerBound() {
  310.         return lower;
  311.     }

  312.     /**
  313.      * {@inheritDoc}
  314.      *
  315.      * <p>The upper bound of the support is equal to the upper bound parameter
  316.      * of the distribution.
  317.      */
  318.     @Override
  319.     public double getSupportUpperBound() {
  320.         return upper;
  321.     }

  322.     /**
  323.      * Clip the value to the range [lower, upper].
  324.      * This is used to handle floating-point error at the support bound.
  325.      *
  326.      * @param x Value x
  327.      * @return x clipped to the range
  328.      */
  329.     private double clipToRange(double x) {
  330.         return clip(x, lower, upper);
  331.     }

  332.     /**
  333.      * Clip the value to the range [lower, upper].
  334.      *
  335.      * @param x Value x
  336.      * @param lower Lower bound (inclusive)
  337.      * @param upper Upper bound (inclusive)
  338.      * @return x clipped to the range
  339.      */
  340.     private static double clip(double x, double lower, double upper) {
  341.         if (x <= lower) {
  342.             return lower;
  343.         }
  344.         return x < upper ? x : upper;
  345.     }

  346.     // Calculation of variance and mean can suffer from cancellation.
  347.     //
  348.     // Use formulas from Jorge Fernandez-de-Cossio-Diaz adapted under the
  349.     // terms of the MIT "Expat" License (see NOTICE and LICENSE).
  350.     //
  351.     // These formulas use the complementary error function
  352.     //   erfcx(z) = erfc(z) * exp(z^2)
  353.     // This avoids computation of exp terms for the Gaussian PDF and then
  354.     // dividing by the error functions erf or erfc:
  355.     //   exp(-0.5*x*x) / erfc(x / sqrt(2)) == 1 / erfcx(x / sqrt(2))
  356.     // At large z the erfcx function is computable but exp(-0.5*z*z) and
  357.     // erfc(z) are zero. Use of these formulas allows computation of the
  358.     // mean and variance for the usable range of the truncated distribution
  359.     // (cdf(a, b) != 0). The variance is not accurate when it approaches
  360.     // machine epsilon (2^-52) at extremely narrow truncations and the
  361.     // computation -> 0.
  362.     //
  363.     // See: https://github.com/cossio/TruncatedNormal.jl

  364.     /**
  365.      * Compute the first moment (mean) of the truncated standard normal distribution.
  366.      *
  367.      * <p>Assumes {@code a <= b}.
  368.      *
  369.      * @param a Lower bound
  370.      * @param b Upper bound
  371.      * @return the first moment
  372.      */
  373.     static double moment1(double a, double b) {
  374.         // Assume a <= b
  375.         if (a == b) {
  376.             return a;
  377.         }
  378.         if (Math.abs(a) > Math.abs(b)) {
  379.             // Subtract from zero to avoid generating -0.0
  380.             return 0 - moment1(-b, -a);
  381.         }

  382.         // Here:
  383.         // |a| <= |b|
  384.         // a < b
  385.         // 0 < b

  386.         if (a <= -MAX_X) {
  387.             // No truncation
  388.             return 0;
  389.         }
  390.         if (b >= MAX_X) {
  391.             // One-sided truncation
  392.             return ROOT_2_PI / Erfcx.value(a / ROOT2);
  393.         }

  394.         // pdf = exp(-0.5*x*x) / sqrt(2*pi)
  395.         // cdf = erfc(-x/sqrt(2)) / 2
  396.         // Compute:
  397.         // -(pdf(b) - pdf(a)) / cdf(b, a)
  398.         // Note:
  399.         // exp(-0.5*b*b) - exp(-0.5*a*a)
  400.         // Use cancellation of powers:
  401.         // exp(-0.5*(b*b-a*a)) * exp(-0.5*a*a) - exp(-0.5*a*a)
  402.         // expm1(-0.5*(b*b-a*a)) * exp(-0.5*a*a)

  403.         // dx = -0.5*(b*b-a*a)
  404.         final double dx = 0.5 * (b + a) * (b - a);
  405.         final double m;
  406.         if (a <= 0) {
  407.             // Opposite signs
  408.             m = ROOT_2_PI * -Math.expm1(-dx) * Math.exp(-0.5 * a * a) / ErfDifference.value(a / ROOT2, b / ROOT2);
  409.         } else {
  410.             final double z = Math.exp(-dx) * Erfcx.value(b / ROOT2) - Erfcx.value(a / ROOT2);
  411.             if (z == 0) {
  412.                 // Occurs when a and b have large magnitudes and are very close
  413.                 return (a + b) * 0.5;
  414.             }
  415.             m = ROOT_2_PI * Math.expm1(-dx) / z;
  416.         }

  417.         // Clip to the range
  418.         return clip(m, a, b);
  419.     }

  420.     /**
  421.      * Compute the second moment of the truncated standard normal distribution.
  422.      *
  423.      * <p>Assumes {@code a <= b}.
  424.      *
  425.      * @param a Lower bound
  426.      * @param b Upper bound
  427.      * @return the first moment
  428.      */
  429.     private static double moment2(double a, double b) {
  430.         // Assume a < b.
  431.         // a == b is handled in the variance method
  432.         if (Math.abs(a) > Math.abs(b)) {
  433.             return moment2(-b, -a);
  434.         }

  435.         // Here:
  436.         // |a| <= |b|
  437.         // a < b
  438.         // 0 < b

  439.         if (a <= -MAX_X) {
  440.             // No truncation
  441.             return 1;
  442.         }
  443.         if (b >= MAX_X) {
  444.             // One-sided truncation.
  445.             // For a -> inf : moment2 -> a*a
  446.             // This occurs when erfcx(z) is approximated by (1/sqrt(pi)) / z and terms
  447.             // cancel. z > 6.71e7, a > 9.49e7
  448.             return 1 + ROOT_2_PI * a / Erfcx.value(a / ROOT2);
  449.         }

  450.         // pdf = exp(-0.5*x*x) / sqrt(2*pi)
  451.         // cdf = erfc(-x/sqrt(2)) / 2
  452.         // Compute:
  453.         // 1 - (b*pdf(b) - a*pdf(a)) / cdf(b, a)
  454.         // = (cdf(b, a) - b*pdf(b) -a*pdf(a)) / cdf(b, a)

  455.         // Note:
  456.         // For z -> 0:
  457.         //   sqrt(pi / 2) * erf(z / sqrt(2)) -> z
  458.         //   z * Math.exp(-0.5 * z * z) -> z
  459.         // Both computations below have cancellation as b -> 0 and the
  460.         // second moment is not computable as the fraction P/Q
  461.         // since P < ulp(Q). This always occurs when b < MIN_X
  462.         // if MIN_X is set at the point where
  463.         //   exp(-0.5 * z * z) / sqrt(2 pi) == 1 / sqrt(2 pi).
  464.         // This is JDK dependent due to variations in Math.exp.
  465.         // For b < MIN_X the second moment can be approximated using
  466.         // a uniform distribution: (b^3 - a^3) / (3b - 3a).
  467.         // In practice it also occurs when b > MIN_X since any a < MIN_X
  468.         // is effectively zero for part of the computation. A
  469.         // threshold to transition to a uniform distribution
  470.         // approximation is a compromise. Also note it will not
  471.         // correct computation when (b-a) is small and is far from 0.
  472.         // Thus the second moment is left to be inaccurate for
  473.         // small ranges (b-a) and the variance -> 0 when the true
  474.         // variance is close to or below machine epsilon.

  475.         double m;

  476.         if (a <= 0) {
  477.             // Opposite signs
  478.             final double ea = ROOT_PI_2 * Erf.value(a / ROOT2);
  479.             final double eb = ROOT_PI_2 * Erf.value(b / ROOT2);
  480.             final double fa = ea - a * Math.exp(-0.5 * a * a);
  481.             final double fb = eb - b * Math.exp(-0.5 * b * b);
  482.             // Assume fb >= fa && eb >= ea
  483.             // If fb <= fa this is a tiny range around 0
  484.             m = (fb - fa) / (eb - ea);
  485.             // Clip to the range
  486.             m = clip(m, 0, 1);
  487.         } else {
  488.             final double dx = 0.5 * (b + a) * (b - a);
  489.             final double ex = Math.exp(-dx);
  490.             final double ea = ROOT_PI_2 * Erfcx.value(a / ROOT2);
  491.             final double eb = ROOT_PI_2 * Erfcx.value(b / ROOT2);
  492.             final double fa = ea + a;
  493.             final double fb = eb + b;
  494.             m = (fa - fb * ex) / (ea - eb * ex);
  495.             // Clip to the range
  496.             m = clip(m, a * a, b * b);
  497.         }
  498.         return m;
  499.     }

  500.     /**
  501.      * Compute the variance of the truncated standard normal distribution.
  502.      *
  503.      * <p>Assumes {@code a <= b}.
  504.      *
  505.      * @param a Lower bound
  506.      * @param b Upper bound
  507.      * @return the first moment
  508.      */
  509.     static double variance(double a, double b) {
  510.         if (a == b) {
  511.             return 0;
  512.         }

  513.         final double m1 = moment1(a, b);
  514.         double m2 = moment2(a, b);
  515.         // variance = m2 - m1*m1
  516.         // rearrange x^2 - y^2 as (x-y)(x+y)
  517.         m2 = Math.sqrt(m2);
  518.         final double variance = (m2 - m1) * (m2 + m1);

  519.         // Detect floating-point error.
  520.         if (variance >= 1) {
  521.             // Note:
  522.             // Extreme truncations in the tails can compute a variance above 1,
  523.             // for example if m2 is infinite: m2 - m1*m1 > 1
  524.             // Detect no truncation as the terms a and b lie far either side of zero;
  525.             // otherwise return 0 to indicate very small unknown variance.
  526.             return a < -1 && b > 1 ? 1 : 0;
  527.         } else if (variance <= 0) {
  528.             // Floating-point error can create negative variance so return 0.
  529.             return 0;
  530.         }

  531.         return variance;
  532.     }
  533. }