View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.statistics.distribution;
19  
20  import org.apache.commons.numbers.gamma.ErfDifference;
21  import org.apache.commons.numbers.gamma.Erfc;
22  import org.apache.commons.numbers.gamma.InverseErfc;
23  import org.apache.commons.rng.UniformRandomProvider;
24  import org.apache.commons.rng.sampling.distribution.GaussianSampler;
25  import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
26  
27  /**
28   * Implementation of the normal (Gaussian) distribution.
29   *
30   * <p>The probability density function of \( X \) is:
31   *
32   * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } \]
33   *
34   * <p>for \( \mu \) the mean,
35   * \( \sigma &gt; 0 \) the standard deviation, and
36   * \( x \in (-\infty, \infty) \).
37   *
38   * @see <a href="https://en.wikipedia.org/wiki/Normal_distribution">Normal distribution (Wikipedia)</a>
39   * @see <a href="https://mathworld.wolfram.com/NormalDistribution.html">Normal distribution (MathWorld)</a>
40   */
41  public final class NormalDistribution extends AbstractContinuousDistribution {
42      /** Mean of this distribution. */
43      private final double mean;
44      /** Standard deviation of this distribution. */
45      private final double standardDeviation;
46      /** The value of {@code log(sd) + 0.5*log(2*pi)} stored for faster computation. */
47      private final double logStandardDeviationPlusHalfLog2Pi;
48      /**
49       * Standard deviation multiplied by sqrt(2).
50       * This is used to avoid a double division when computing the value passed to the
51       * error function:
52       * <pre>
53       *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
54       *  </pre>
55       * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
56       * in differences due to rounding error that show increasingly large relative
57       * differences as the error function computes close to 0 in the extreme tail.
58       */
59      private final double sdSqrt2;
60      /**
61       * Standard deviation multiplied by sqrt(2 pi). Computed to high precision.
62       */
63      private final double sdSqrt2pi;
64  
65      /**
66       * @param mean Mean for this distribution.
67       * @param sd Standard deviation for this distribution.
68       */
69      private NormalDistribution(double mean,
70                                 double sd) {
71          this.mean = mean;
72          standardDeviation = sd;
73          logStandardDeviationPlusHalfLog2Pi = Math.log(sd) + Constants.HALF_LOG_TWO_PI;
74          // Minimise rounding error by computing sqrt(2 * sd * sd) exactly.
75          // Compute using extended precision with care to avoid over/underflow.
76          sdSqrt2 = ExtendedPrecision.sqrt2xx(sd);
77          // Compute sd * sqrt(2 * pi)
78          sdSqrt2pi = ExtendedPrecision.xsqrt2pi(sd);
79      }
80  
81      /**
82       * Creates a normal distribution.
83       *
84       * @param mean Mean for this distribution.
85       * @param sd Standard deviation for this distribution.
86       * @return the distribution
87       * @throws IllegalArgumentException if {@code sd <= 0}.
88       */
89      public static NormalDistribution of(double mean,
90                                          double sd) {
91          if (sd > 0) {
92              return new NormalDistribution(mean, sd);
93          }
94          // zero, negative or nan
95          throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
96      }
97  
98      /**
99       * Gets the standard deviation parameter of this distribution.
100      *
101      * @return the standard deviation.
102      */
103     public double getStandardDeviation() {
104         return standardDeviation;
105     }
106 
107     /** {@inheritDoc} */
108     @Override
109     public double density(double x) {
110         final double z = (x - mean) / standardDeviation;
111         return ExtendedPrecision.expmhxx(z) / sdSqrt2pi;
112     }
113 
114     /** {@inheritDoc} */
115     @Override
116     public double probability(double x0,
117                               double x1) {
118         if (x0 > x1) {
119             throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
120                                             x0, x1);
121         }
122         final double v0 = (x0 - mean) / sdSqrt2;
123         final double v1 = (x1 - mean) / sdSqrt2;
124         return 0.5 * ErfDifference.value(v0, v1);
125     }
126 
127     /** {@inheritDoc} */
128     @Override
129     public double logDensity(double x) {
130         final double z = (x - mean) / standardDeviation;
131         return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
132     }
133 
134     /** {@inheritDoc} */
135     @Override
136     public double cumulativeProbability(double x)  {
137         final double dev = x - mean;
138         return 0.5 * Erfc.value(-dev / sdSqrt2);
139     }
140 
141     /** {@inheritDoc} */
142     @Override
143     public double survivalProbability(double x) {
144         final double dev = x - mean;
145         return 0.5 * Erfc.value(dev / sdSqrt2);
146     }
147 
148     /** {@inheritDoc} */
149     @Override
150     public double inverseCumulativeProbability(double p) {
151         ArgumentUtils.checkProbability(p);
152         return mean - sdSqrt2 * InverseErfc.value(2 * p);
153     }
154 
155     /** {@inheritDoc} */
156     @Override
157     public double inverseSurvivalProbability(double p) {
158         ArgumentUtils.checkProbability(p);
159         return mean + sdSqrt2 * InverseErfc.value(2 * p);
160     }
161 
162     /** {@inheritDoc} */
163     @Override
164     public double getMean() {
165         return mean;
166     }
167 
168     /**
169      * {@inheritDoc}
170      *
171      * <p>For standard deviation parameter \( \sigma \), the variance is \( \sigma^2 \).
172      */
173     @Override
174     public double getVariance() {
175         final double s = getStandardDeviation();
176         return s * s;
177     }
178 
179     /**
180      * {@inheritDoc}
181      *
182      * <p>The lower bound of the support is always negative infinity.
183      *
184      * @return {@linkplain Double#NEGATIVE_INFINITY negative infinity}.
185      */
186     @Override
187     public double getSupportLowerBound() {
188         return Double.NEGATIVE_INFINITY;
189     }
190 
191     /**
192      * {@inheritDoc}
193      *
194      * <p>The upper bound of the support is always positive infinity.
195      *
196      * @return {@linkplain Double#POSITIVE_INFINITY positive infinity}.
197      */
198     @Override
199     public double getSupportUpperBound() {
200         return Double.POSITIVE_INFINITY;
201     }
202 
203     /** {@inheritDoc} */
204     @Override
205     public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
206         // Gaussian distribution sampler.
207         return GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
208                                   mean, standardDeviation)::sample;
209     }
210 }