NormalDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.statistics.distribution;

  18. import org.apache.commons.numbers.gamma.ErfDifference;
  19. import org.apache.commons.numbers.gamma.Erfc;
  20. import org.apache.commons.numbers.gamma.InverseErfc;
  21. import org.apache.commons.rng.UniformRandomProvider;
  22. import org.apache.commons.rng.sampling.distribution.GaussianSampler;
  23. import org.apache.commons.rng.sampling.distribution.ZigguratSampler;

  24. /**
  25.  * Implementation of the normal (Gaussian) distribution.
  26.  *
  27.  * <p>The probability density function of \( X \) is:
  28.  *
  29.  * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } \]
  30.  *
  31.  * <p>for \( \mu \) the mean,
  32.  * \( \sigma &gt; 0 \) the standard deviation, and
  33.  * \( x \in (-\infty, \infty) \).
  34.  *
  35.  * @see <a href="https://en.wikipedia.org/wiki/Normal_distribution">Normal distribution (Wikipedia)</a>
  36.  * @see <a href="https://mathworld.wolfram.com/NormalDistribution.html">Normal distribution (MathWorld)</a>
  37.  */
  38. public final class NormalDistribution extends AbstractContinuousDistribution {
  39.     /** Mean of this distribution. */
  40.     private final double mean;
  41.     /** Standard deviation of this distribution. */
  42.     private final double standardDeviation;
  43.     /** The value of {@code log(sd) + 0.5*log(2*pi)} stored for faster computation. */
  44.     private final double logStandardDeviationPlusHalfLog2Pi;
  45.     /**
  46.      * Standard deviation multiplied by sqrt(2).
  47.      * This is used to avoid a double division when computing the value passed to the
  48.      * error function:
  49.      * <pre>
  50.      *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
  51.      *  </pre>
  52.      * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
  53.      * in differences due to rounding error that show increasingly large relative
  54.      * differences as the error function computes close to 0 in the extreme tail.
  55.      */
  56.     private final double sdSqrt2;
  57.     /**
  58.      * Standard deviation multiplied by sqrt(2 pi). Computed to high precision.
  59.      */
  60.     private final double sdSqrt2pi;

  61.     /**
  62.      * @param mean Mean for this distribution.
  63.      * @param sd Standard deviation for this distribution.
  64.      */
  65.     private NormalDistribution(double mean,
  66.                                double sd) {
  67.         this.mean = mean;
  68.         standardDeviation = sd;
  69.         logStandardDeviationPlusHalfLog2Pi = Math.log(sd) + Constants.HALF_LOG_TWO_PI;
  70.         // Minimise rounding error by computing sqrt(2 * sd * sd) exactly.
  71.         // Compute using extended precision with care to avoid over/underflow.
  72.         sdSqrt2 = ExtendedPrecision.sqrt2xx(sd);
  73.         // Compute sd * sqrt(2 * pi)
  74.         sdSqrt2pi = ExtendedPrecision.xsqrt2pi(sd);
  75.     }

  76.     /**
  77.      * Creates a normal distribution.
  78.      *
  79.      * @param mean Mean for this distribution.
  80.      * @param sd Standard deviation for this distribution.
  81.      * @return the distribution
  82.      * @throws IllegalArgumentException if {@code sd <= 0}.
  83.      */
  84.     public static NormalDistribution of(double mean,
  85.                                         double sd) {
  86.         if (sd > 0) {
  87.             return new NormalDistribution(mean, sd);
  88.         }
  89.         // zero, negative or nan
  90.         throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
  91.     }

  92.     /**
  93.      * Gets the standard deviation parameter of this distribution.
  94.      *
  95.      * @return the standard deviation.
  96.      */
  97.     public double getStandardDeviation() {
  98.         return standardDeviation;
  99.     }

  100.     /** {@inheritDoc} */
  101.     @Override
  102.     public double density(double x) {
  103.         final double z = (x - mean) / standardDeviation;
  104.         return ExtendedPrecision.expmhxx(z) / sdSqrt2pi;
  105.     }

  106.     /** {@inheritDoc} */
  107.     @Override
  108.     public double probability(double x0,
  109.                               double x1) {
  110.         if (x0 > x1) {
  111.             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
  112.                                             x0, x1);
  113.         }
  114.         final double v0 = (x0 - mean) / sdSqrt2;
  115.         final double v1 = (x1 - mean) / sdSqrt2;
  116.         return 0.5 * ErfDifference.value(v0, v1);
  117.     }

  118.     /** {@inheritDoc} */
  119.     @Override
  120.     public double logDensity(double x) {
  121.         final double z = (x - mean) / standardDeviation;
  122.         return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
  123.     }

  124.     /** {@inheritDoc} */
  125.     @Override
  126.     public double cumulativeProbability(double x)  {
  127.         final double dev = x - mean;
  128.         return 0.5 * Erfc.value(-dev / sdSqrt2);
  129.     }

  130.     /** {@inheritDoc} */
  131.     @Override
  132.     public double survivalProbability(double x) {
  133.         final double dev = x - mean;
  134.         return 0.5 * Erfc.value(dev / sdSqrt2);
  135.     }

  136.     /** {@inheritDoc} */
  137.     @Override
  138.     public double inverseCumulativeProbability(double p) {
  139.         ArgumentUtils.checkProbability(p);
  140.         return mean - sdSqrt2 * InverseErfc.value(2 * p);
  141.     }

  142.     /** {@inheritDoc} */
  143.     @Override
  144.     public double inverseSurvivalProbability(double p) {
  145.         ArgumentUtils.checkProbability(p);
  146.         return mean + sdSqrt2 * InverseErfc.value(2 * p);
  147.     }

  148.     /** {@inheritDoc} */
  149.     @Override
  150.     public double getMean() {
  151.         return mean;
  152.     }

  153.     /**
  154.      * {@inheritDoc}
  155.      *
  156.      * <p>For standard deviation parameter \( \sigma \), the variance is \( \sigma^2 \).
  157.      */
  158.     @Override
  159.     public double getVariance() {
  160.         final double s = getStandardDeviation();
  161.         return s * s;
  162.     }

  163.     /**
  164.      * {@inheritDoc}
  165.      *
  166.      * <p>The lower bound of the support is always negative infinity.
  167.      *
  168.      * @return {@linkplain Double#NEGATIVE_INFINITY negative infinity}.
  169.      */
  170.     @Override
  171.     public double getSupportLowerBound() {
  172.         return Double.NEGATIVE_INFINITY;
  173.     }

  174.     /**
  175.      * {@inheritDoc}
  176.      *
  177.      * <p>The upper bound of the support is always positive infinity.
  178.      *
  179.      * @return {@linkplain Double#POSITIVE_INFINITY positive infinity}.
  180.      */
  181.     @Override
  182.     public double getSupportUpperBound() {
  183.         return Double.POSITIVE_INFINITY;
  184.     }

  185.     /** {@inheritDoc} */
  186.     @Override
  187.     public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
  188.         // Gaussian distribution sampler.
  189.         return GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
  190.                                   mean, standardDeviation)::sample;
  191.     }
  192. }