FoldedNormalDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.statistics.distribution;

  18. import org.apache.commons.numbers.gamma.Erf;
  19. import org.apache.commons.numbers.gamma.ErfDifference;
  20. import org.apache.commons.numbers.gamma.Erfc;
  21. import org.apache.commons.numbers.gamma.InverseErf;
  22. import org.apache.commons.numbers.gamma.InverseErfc;
  23. import org.apache.commons.rng.UniformRandomProvider;
  24. import org.apache.commons.rng.sampling.distribution.GaussianSampler;
  25. import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
  26. import org.apache.commons.rng.sampling.distribution.ZigguratSampler;

  27. /**
  28.  * Implementation of the folded normal distribution.
  29.  *
  30.  * <p>Given a normally distributed random variable \( X \) with mean \( \mu \) and variance
  31.  * \( \sigma^2 \), the random variable \( Y = |X| \) has a folded normal distribution. This is
  32.  * equivalent to not recording the sign from a normally distributed random variable.
  33.  *
  34.  * <p>The probability density function of \( X \) is:
  35.  *
  36.  * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } +
  37.  *                           \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x+\mu}{\sigma} \right)^2 }\]
  38.  *
  39.  * <p>for \( \mu \) the location,
  40.  * \( \sigma &gt; 0 \) the scale, and
  41.  * \( x \in [0, \infty) \).
  42.  *
  43.  * <p>If the location \( \mu \) is 0 this reduces to the half-normal distribution.
  44.  *
  45.  * @see <a href="https://en.wikipedia.org/wiki/Folded_normal_distribution">Folded normal distribution (Wikipedia)</a>
  46.  * @see <a href="https://en.wikipedia.org/wiki/Half-normal_distribution">Half-normal distribution (Wikipedia)</a>
  47.  * @since 1.1
  48.  */
  49. public abstract class FoldedNormalDistribution extends AbstractContinuousDistribution {
  50.     /** The scale. */
  51.     final double sigma;
  52.     /**
  53.      * The scale multiplied by sqrt(2).
  54.      * This is used to avoid a double division when computing the value passed to the
  55.      * error function:
  56.      * <pre>
  57.      *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
  58.      *  </pre>
  59.      * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
  60.      * in differences due to rounding error that show increasingly large relative
  61.      * differences as the error function computes close to 0 in the extreme tail.
  62.      */
  63.     final double sigmaSqrt2;
  64.     /**
  65.      * The scale multiplied by sqrt(2 pi). Computed to high precision.
  66.      */
  67.     final double sigmaSqrt2pi;

  68.     /**
  69.      * Regular implementation of the folded normal distribution.
  70.      */
  71.     private static class RegularFoldedNormalDistribution extends FoldedNormalDistribution {
  72.         /** The location. */
  73.         private final double mu;
  74.         /** Cached value for inverse probability function. */
  75.         private final double mean;
  76.         /** Cached value for inverse probability function. */
  77.         private final double variance;

  78.         /**
  79.          * @param mu Location parameter.
  80.          * @param sigma Scale parameter.
  81.          */
  82.         RegularFoldedNormalDistribution(double mu, double sigma) {
  83.             super(sigma);
  84.             this.mu = mu;

  85.             final double a = mu / sigmaSqrt2;
  86.             mean = sigma * Constants.ROOT_TWO_DIV_PI * Math.exp(-a * a) + mu * Erf.value(a);
  87.             this.variance = mu * mu + sigma * sigma - mean * mean;
  88.         }

  89.         @Override
  90.         public double getMu() {
  91.             return mu;
  92.         }

  93.         @Override
  94.         public double density(double x) {
  95.             if (x < 0) {
  96.                 return 0;
  97.             }
  98.             final double vm = (x - mu) / sigma;
  99.             final double vp = (x + mu) / sigma;
  100.             return (ExtendedPrecision.expmhxx(vm) + ExtendedPrecision.expmhxx(vp)) / sigmaSqrt2pi;
  101.         }

  102.         @Override
  103.         public double probability(double x0,
  104.                                   double x1) {
  105.             if (x0 > x1) {
  106.                 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
  107.                                                 x0, x1);
  108.             }
  109.             if (x0 <= 0) {
  110.                 return cumulativeProbability(x1);
  111.             }
  112.             // Assumes x1 >= x0 && x0 > 0
  113.             final double v0m = (x0 - mu) / sigmaSqrt2;
  114.             final double v1m = (x1 - mu) / sigmaSqrt2;
  115.             final double v0p = (x0 + mu) / sigmaSqrt2;
  116.             final double v1p = (x1 + mu) / sigmaSqrt2;
  117.             return 0.5 * (ErfDifference.value(v0m, v1m) + ErfDifference.value(v0p, v1p));
  118.         }

  119.         @Override
  120.         public double cumulativeProbability(double x) {
  121.             if (x <= 0) {
  122.                 return 0;
  123.             }
  124.             return 0.5 * (Erf.value((x - mu) / sigmaSqrt2) + Erf.value((x + mu) / sigmaSqrt2));
  125.         }

  126.         @Override
  127.         public double survivalProbability(double x) {
  128.             if (x <= 0) {
  129.                 return 1;
  130.             }
  131.             return 0.5 * (Erfc.value((x - mu) / sigmaSqrt2) + Erfc.value((x + mu) / sigmaSqrt2));
  132.         }

  133.         @Override
  134.         public double getMean() {
  135.             return mean;
  136.         }

  137.         @Override
  138.         public double getVariance() {
  139.             return variance;
  140.         }

  141.         @Override
  142.         public Sampler createSampler(UniformRandomProvider rng) {
  143.             // Return the absolute of a Gaussian distribution sampler.
  144.             final SharedStateContinuousSampler s =
  145.                 GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng), mu, sigma);
  146.             return () -> Math.abs(s.sample());
  147.         }
  148.     }

  149.     /**
  150.      * Specialisation for the half-normal distribution.
  151.      *
  152.      * <p>Elimination of the {@code mu} location parameter simplifies the probability
  153.      * functions and allows computation of the log density and inverse CDF/SF.
  154.      */
  155.     private static class HalfNormalDistribution extends FoldedNormalDistribution {
  156.         /** Variance constant (1 - 2/pi). Computed using Matlab's VPA to 30 digits. */
  157.         private static final double VAR = 0.36338022763241865692446494650994;
  158.         /** The value of {@code log(sigma) + 0.5 * log(2*PI)} stored for faster computation. */
  159.         private final double logSigmaPlusHalfLog2Pi;

  160.         /**
  161.          * @param sigma Scale parameter.
  162.          */
  163.         HalfNormalDistribution(double sigma) {
  164.             super(sigma);
  165.             logSigmaPlusHalfLog2Pi = Math.log(sigma) + Constants.HALF_LOG_TWO_PI;
  166.         }

  167.         @Override
  168.         public double getMu() {
  169.             return 0;
  170.         }

  171.         @Override
  172.         public double density(double x) {
  173.             if (x < 0) {
  174.                 return 0;
  175.             }
  176.             return 2 * ExtendedPrecision.expmhxx(x / sigma) / sigmaSqrt2pi;
  177.         }

  178.         @Override
  179.         public double probability(double x0,
  180.                                   double x1) {
  181.             if (x0 > x1) {
  182.                 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
  183.                                                 x0, x1);
  184.             }
  185.             if (x0 <= 0) {
  186.                 return cumulativeProbability(x1);
  187.             }
  188.             // Assumes x1 >= x0 && x0 > 0
  189.             return ErfDifference.value(x0 / sigmaSqrt2, x1 / sigmaSqrt2);
  190.         }

  191.         @Override
  192.         public double logDensity(double x) {
  193.             if (x < 0) {
  194.                 return Double.NEGATIVE_INFINITY;
  195.             }
  196.             final double z = x / sigma;
  197.             return Constants.LN_TWO - 0.5 * z * z - logSigmaPlusHalfLog2Pi;
  198.         }

  199.         @Override
  200.         public double cumulativeProbability(double x) {
  201.             if (x <= 0) {
  202.                 return 0;
  203.             }
  204.             return Erf.value(x / sigmaSqrt2);
  205.         }

  206.         @Override
  207.         public double survivalProbability(double x) {
  208.             if (x <= 0) {
  209.                 return 1;
  210.             }
  211.             return Erfc.value(x / sigmaSqrt2);
  212.         }

  213.         @Override
  214.         public double inverseCumulativeProbability(double p) {
  215.             ArgumentUtils.checkProbability(p);
  216.             // Addition of 0.0 ensures 0.0 is returned for p=-0.0
  217.             return 0.0 + sigmaSqrt2 * InverseErf.value(p);
  218.         }

  219.         /** {@inheritDoc} */
  220.         @Override
  221.         public double inverseSurvivalProbability(double p) {
  222.             ArgumentUtils.checkProbability(p);
  223.             return sigmaSqrt2 * InverseErfc.value(p);
  224.         }

  225.         @Override
  226.         public double getMean() {
  227.             return sigma * Constants.ROOT_TWO_DIV_PI;
  228.         }

  229.         @Override
  230.         public double getVariance() {
  231.             // sigma^2 - mean^2
  232.             // sigma^2 - (sigma^2 * 2/pi)
  233.             return sigma * sigma * VAR;
  234.         }

  235.         @Override
  236.         public Sampler createSampler(UniformRandomProvider rng) {
  237.             // Return the absolute of a Gaussian distribution sampler.
  238.             final SharedStateContinuousSampler s = ZigguratSampler.NormalizedGaussian.of(rng);
  239.             return () -> Math.abs(s.sample() * sigma);
  240.         }
  241.     }

  242.     /**
  243.      * @param sigma Scale parameter.
  244.      */
  245.     FoldedNormalDistribution(double sigma) {
  246.         this.sigma = sigma;
  247.         // Minimise rounding error by computing sqrt(2 * sigma * sigma) exactly.
  248.         // Compute using extended precision with care to avoid over/underflow.
  249.         sigmaSqrt2 = ExtendedPrecision.sqrt2xx(sigma);
  250.         // Compute sigma * sqrt(2 * pi)
  251.         sigmaSqrt2pi = ExtendedPrecision.xsqrt2pi(sigma);
  252.     }

  253.     /**
  254.      * Creates a folded normal distribution. If the location {@code mu} is zero this is
  255.      * the half-normal distribution.
  256.      *
  257.      * @param mu Location parameter.
  258.      * @param sigma Scale parameter.
  259.      * @return the distribution
  260.      * @throws IllegalArgumentException if {@code sigma <= 0}.
  261.      */
  262.     public static FoldedNormalDistribution of(double mu,
  263.                                               double sigma) {
  264.         if (sigma > 0) {
  265.             if (mu == 0) {
  266.                 return new HalfNormalDistribution(sigma);
  267.             }
  268.             return new RegularFoldedNormalDistribution(mu, sigma);
  269.         }
  270.         // scale is zero, negative or nan
  271.         throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sigma);
  272.     }

  273.     /**
  274.      * Gets the location parameter \( \mu \) of this distribution.
  275.      *
  276.      * @return the mu parameter.
  277.      */
  278.     public abstract double getMu();

  279.     /**
  280.      * Gets the scale parameter \( \sigma \) of this distribution.
  281.      *
  282.      * @return the sigma parameter.
  283.      */
  284.     public double getSigma() {
  285.         return sigma;
  286.     }

  287.     /**
  288.      * {@inheritDoc}
  289.      *
  290.      *
  291.      * <p>For location parameter \( \mu \) and scale parameter \( \sigma \), the mean is:
  292.      *
  293.      * <p>\[ \sigma \sqrt{ \frac 2 \pi } \exp \left( \frac{-\mu^2}{2\sigma^2} \right) +
  294.      *       \mu \operatorname{erf} \left( \frac \mu {\sqrt{2\sigma^2}} \right) \]
  295.      *
  296.      * <p>where \( \operatorname{erf} \) is the error function.
  297.      */
  298.     @Override
  299.     public abstract double getMean();

  300.     /**
  301.      * {@inheritDoc}
  302.      *
  303.      * <p>For location parameter \( \mu \), scale parameter \( \sigma \) and a distribution
  304.      * mean \( \mu_Y \), the variance is:
  305.      *
  306.      * <p>\[ \mu^2 + \sigma^2 - \mu_{Y}^2 \]
  307.      */
  308.     @Override
  309.     public abstract double getVariance();

  310.     /**
  311.      * {@inheritDoc}
  312.      *
  313.      * <p>The lower bound of the support is always 0.
  314.      *
  315.      * @return 0.
  316.      */
  317.     @Override
  318.     public double getSupportLowerBound() {
  319.         return 0.0;
  320.     }

  321.     /**
  322.      * {@inheritDoc}
  323.      *
  324.      * <p>The upper bound of the support is always positive infinity.
  325.      *
  326.      * @return {@linkplain Double#POSITIVE_INFINITY positive infinity}.
  327.      */
  328.     @Override
  329.     public double getSupportUpperBound() {
  330.         return Double.POSITIVE_INFINITY;
  331.     }
  332. }