001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.statistics.distribution;
019
020import org.apache.commons.numbers.gamma.Erfc;
021import org.apache.commons.numbers.gamma.InverseErfc;
022import org.apache.commons.numbers.gamma.ErfDifference;
023import org.apache.commons.rng.UniformRandomProvider;
024import org.apache.commons.rng.sampling.distribution.GaussianSampler;
025import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
026
027/**
028 * Implementation of the normal (Gaussian) distribution.
029 *
030 * <p>The probability density function of \( X \) is:
031 *
032 * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } \]
033 *
034 * <p>for \( \mu \) the mean,
035 * \( \sigma &gt; 0 \) the standard deviation, and
036 * \( x \in (-\infty, \infty) \).
037 *
038 * @see <a href="https://en.wikipedia.org/wiki/Normal_distribution">Normal distribution (Wikipedia)</a>
039 * @see <a href="https://mathworld.wolfram.com/NormalDistribution.html">Normal distribution (MathWorld)</a>
040 */
041public final class NormalDistribution extends AbstractContinuousDistribution {
042    /** 0.5 * ln(2 * pi). Computed to 25-digits precision. */
043    private static final double HALF_LOG_TWO_PI = 0.9189385332046727417803297;
044
045    /** Mean of this distribution. */
046    private final double mean;
047    /** Standard deviation of this distribution. */
048    private final double standardDeviation;
049    /** The value of {@code log(sd) + 0.5*log(2*pi)} stored for faster computation. */
050    private final double logStandardDeviationPlusHalfLog2Pi;
051    /**
052     * Standard deviation multiplied by sqrt(2).
053     * This is used to avoid a double division when computing the value passed to the
054     * error function:
055     * <pre>
056     *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
057     *  </pre>
058     * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
059     * in differences due to rounding error that show increasingly large relative
060     * differences as the error function computes close to 0 in the extreme tail.
061     */
062    private final double sdSqrt2;
063    /**
064     * Standard deviation multiplied by sqrt(2 pi). Computed to high precision.
065     */
066    private final double sdSqrt2pi;
067
068    /**
069     * @param mean Mean for this distribution.
070     * @param sd Standard deviation for this distribution.
071     */
072    private NormalDistribution(double mean,
073                               double sd) {
074        this.mean = mean;
075        standardDeviation = sd;
076        logStandardDeviationPlusHalfLog2Pi = Math.log(sd) + HALF_LOG_TWO_PI;
077        // Minimise rounding error by computing sqrt(2 * sd * sd) exactly.
078        // Compute using extended precision with care to avoid over/underflow.
079        sdSqrt2 = ExtendedPrecision.sqrt2xx(sd);
080        // Compute sd * sqrt(2 * pi)
081        sdSqrt2pi = ExtendedPrecision.xsqrt2pi(sd);
082    }
083
084    /**
085     * Creates a normal distribution.
086     *
087     * @param mean Mean for this distribution.
088     * @param sd Standard deviation for this distribution.
089     * @return the distribution
090     * @throws IllegalArgumentException if {@code sd <= 0}.
091     */
092    public static NormalDistribution of(double mean,
093                                        double sd) {
094        if (sd > 0) {
095            return new NormalDistribution(mean, sd);
096        }
097        // zero, negative or nan
098        throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
099    }
100
101    /**
102     * Gets the standard deviation parameter of this distribution.
103     *
104     * @return the standard deviation.
105     */
106    public double getStandardDeviation() {
107        return standardDeviation;
108    }
109
110    /** {@inheritDoc} */
111    @Override
112    public double density(double x) {
113        final double z = (x - mean) / standardDeviation;
114        return ExtendedPrecision.expmhxx(z) / sdSqrt2pi;
115    }
116
117    /** {@inheritDoc} */
118    @Override
119    public double probability(double x0,
120                              double x1) {
121        if (x0 > x1) {
122            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
123                                            x0, x1);
124        }
125        final double v0 = (x0 - mean) / sdSqrt2;
126        final double v1 = (x1 - mean) / sdSqrt2;
127        return 0.5 * ErfDifference.value(v0, v1);
128    }
129
130    /** {@inheritDoc} */
131    @Override
132    public double logDensity(double x) {
133        final double z = (x - mean) / standardDeviation;
134        return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
135    }
136
137    /** {@inheritDoc} */
138    @Override
139    public double cumulativeProbability(double x)  {
140        final double dev = x - mean;
141        return 0.5 * Erfc.value(-dev / sdSqrt2);
142    }
143
144    /** {@inheritDoc} */
145    @Override
146    public double survivalProbability(double x) {
147        final double dev = x - mean;
148        return 0.5 * Erfc.value(dev / sdSqrt2);
149    }
150
151    /** {@inheritDoc} */
152    @Override
153    public double inverseCumulativeProbability(double p) {
154        ArgumentUtils.checkProbability(p);
155        return mean - sdSqrt2 * InverseErfc.value(2 * p);
156    }
157
158    /** {@inheritDoc} */
159    @Override
160    public double inverseSurvivalProbability(double p) {
161        ArgumentUtils.checkProbability(p);
162        return mean + sdSqrt2 * InverseErfc.value(2 * p);
163    }
164
165    /** {@inheritDoc} */
166    @Override
167    public double getMean() {
168        return mean;
169    }
170
171    /**
172     * {@inheritDoc}
173     *
174     * <p>For standard deviation parameter \( \sigma \), the variance is \( \sigma^2 \).
175     */
176    @Override
177    public double getVariance() {
178        final double s = getStandardDeviation();
179        return s * s;
180    }
181
182    /**
183     * {@inheritDoc}
184     *
185     * <p>The lower bound of the support is always negative infinity.
186     *
187     * @return {@link Double#NEGATIVE_INFINITY negative infinity}.
188     */
189    @Override
190    public double getSupportLowerBound() {
191        return Double.NEGATIVE_INFINITY;
192    }
193
194    /**
195     * {@inheritDoc}
196     *
197     * <p>The upper bound of the support is always positive infinity.
198     *
199     * @return {@link Double#POSITIVE_INFINITY positive infinity}.
200     */
201    @Override
202    public double getSupportUpperBound() {
203        return Double.POSITIVE_INFINITY;
204    }
205
206    /** {@inheritDoc} */
207    @Override
208    public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
209        // Gaussian distribution sampler.
210        return GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
211                                  mean, standardDeviation)::sample;
212    }
213}