View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.distribution;
18  
19  import org.apache.commons.numbers.gamma.Erf;
20  import org.apache.commons.numbers.gamma.ErfDifference;
21  import org.apache.commons.numbers.gamma.Erfc;
22  import org.apache.commons.numbers.gamma.InverseErf;
23  import org.apache.commons.numbers.gamma.InverseErfc;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.sampling.distribution.GaussianSampler;
26  import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
27  import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
28  
29  /**
30   * Implementation of the folded normal distribution.
31   *
32   * <p>Given a normally distributed random variable \( X \) with mean \( \mu \) and variance
33   * \( \sigma^2 \), the random variable \( Y = |X| \) has a folded normal distribution. This is
34   * equivalent to not recording the sign from a normally distributed random variable.
35   *
36   * <p>The probability density function of \( X \) is:
37   *
38   * <p>\[ f(x; \mu, \sigma) = \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x-\mu}{\sigma} \right)^2 } +
39   *                           \frac 1 {\sigma\sqrt{2\pi}} e^{-{\frac 1 2}\left( \frac{x+\mu}{\sigma} \right)^2 }\]
40   *
41   * <p>for \( \mu \) the location,
42   * \( \sigma &gt; 0 \) the scale, and
43   * \( x \in [0, \infty) \).
44   *
45   * <p>If the location \( \mu \) is 0 this reduces to the half-normal distribution.
46   *
47   * @see <a href="https://en.wikipedia.org/wiki/Folded_normal_distribution">Folded normal distribution (Wikipedia)</a>
48   * @see <a href="https://en.wikipedia.org/wiki/Half-normal_distribution">Half-normal distribution (Wikipedia)</a>
49   * @since 1.1
50   */
51  public abstract class FoldedNormalDistribution extends AbstractContinuousDistribution {
52      /** The scale. */
53      final double sigma;
54      /**
55       * The scale multiplied by sqrt(2).
56       * This is used to avoid a double division when computing the value passed to the
57       * error function:
58       * <pre>
59       *  ((x - u) / sd) / sqrt(2) == (x - u) / (sd * sqrt(2)).
60       *  </pre>
61       * <p>Note: Implementations may first normalise x and then divide by sqrt(2) resulting
62       * in differences due to rounding error that show increasingly large relative
63       * differences as the error function computes close to 0 in the extreme tail.
64       */
65      final double sigmaSqrt2;
66      /**
67       * The scale multiplied by sqrt(2 pi). Computed to high precision.
68       */
69      final double sigmaSqrt2pi;
70  
71      /**
72       * Regular implementation of the folded normal distribution.
73       */
74      private static class RegularFoldedNormalDistribution extends FoldedNormalDistribution {
75          /** The location. */
76          private final double mu;
77          /** Cached value for inverse probability function. */
78          private final double mean;
79          /** Cached value for inverse probability function. */
80          private final double variance;
81  
82          /**
83           * @param mu Location parameter.
84           * @param sigma Scale parameter.
85           */
86          RegularFoldedNormalDistribution(double mu, double sigma) {
87              super(sigma);
88              this.mu = mu;
89  
90              final double a = mu / sigmaSqrt2;
91              mean = sigma * Constants.ROOT_TWO_DIV_PI * Math.exp(-a * a) + mu * Erf.value(a);
92              this.variance = mu * mu + sigma * sigma - mean * mean;
93          }
94  
95          @Override
96          public double getMu() {
97              return mu;
98          }
99  
100         @Override
101         public double density(double x) {
102             if (x < 0) {
103                 return 0;
104             }
105             final double vm = (x - mu) / sigma;
106             final double vp = (x + mu) / sigma;
107             return (ExtendedPrecision.expmhxx(vm) + ExtendedPrecision.expmhxx(vp)) / sigmaSqrt2pi;
108         }
109 
110         @Override
111         public double probability(double x0,
112                                   double x1) {
113             if (x0 > x1) {
114                 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
115                                                 x0, x1);
116             }
117             if (x0 <= 0) {
118                 return cumulativeProbability(x1);
119             }
120             // Assumes x1 >= x0 && x0 > 0
121             final double v0m = (x0 - mu) / sigmaSqrt2;
122             final double v1m = (x1 - mu) / sigmaSqrt2;
123             final double v0p = (x0 + mu) / sigmaSqrt2;
124             final double v1p = (x1 + mu) / sigmaSqrt2;
125             return 0.5 * (ErfDifference.value(v0m, v1m) + ErfDifference.value(v0p, v1p));
126         }
127 
128         @Override
129         public double cumulativeProbability(double x) {
130             if (x <= 0) {
131                 return 0;
132             }
133             return 0.5 * (Erf.value((x - mu) / sigmaSqrt2) + Erf.value((x + mu) / sigmaSqrt2));
134         }
135 
136         @Override
137         public double survivalProbability(double x) {
138             if (x <= 0) {
139                 return 1;
140             }
141             return 0.5 * (Erfc.value((x - mu) / sigmaSqrt2) + Erfc.value((x + mu) / sigmaSqrt2));
142         }
143 
144         @Override
145         public double getMean() {
146             return mean;
147         }
148 
149         @Override
150         public double getVariance() {
151             return variance;
152         }
153 
154         @Override
155         public Sampler createSampler(UniformRandomProvider rng) {
156             // Return the absolute of a Gaussian distribution sampler.
157             final SharedStateContinuousSampler s =
158                 GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng), mu, sigma);
159             return () -> Math.abs(s.sample());
160         }
161     }
162 
163     /**
164      * Specialisation for the half-normal distribution.
165      *
166      * <p>Elimination of the {@code mu} location parameter simplifies the probability
167      * functions and allows computation of the log density and inverse CDF/SF.
168      */
169     private static class HalfNormalDistribution extends FoldedNormalDistribution {
170         /** Variance constant (1 - 2/pi). Computed using Matlab's VPA to 30 digits. */
171         private static final double VAR = 0.36338022763241865692446494650994;
172         /** The value of {@code log(sigma) + 0.5 * log(2*PI)} stored for faster computation. */
173         private final double logSigmaPlusHalfLog2Pi;
174 
175         /**
176          * @param sigma Scale parameter.
177          */
178         HalfNormalDistribution(double sigma) {
179             super(sigma);
180             logSigmaPlusHalfLog2Pi = Math.log(sigma) + Constants.HALF_LOG_TWO_PI;
181         }
182 
183         @Override
184         public double getMu() {
185             return 0;
186         }
187 
188         @Override
189         public double density(double x) {
190             if (x < 0) {
191                 return 0;
192             }
193             return 2 * ExtendedPrecision.expmhxx(x / sigma) / sigmaSqrt2pi;
194         }
195 
196         @Override
197         public double probability(double x0,
198                                   double x1) {
199             if (x0 > x1) {
200                 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
201                                                 x0, x1);
202             }
203             if (x0 <= 0) {
204                 return cumulativeProbability(x1);
205             }
206             // Assumes x1 >= x0 && x0 > 0
207             return ErfDifference.value(x0 / sigmaSqrt2, x1 / sigmaSqrt2);
208         }
209 
210         @Override
211         public double logDensity(double x) {
212             if (x < 0) {
213                 return Double.NEGATIVE_INFINITY;
214             }
215             final double z = x / sigma;
216             return Constants.LN_TWO - 0.5 * z * z - logSigmaPlusHalfLog2Pi;
217         }
218 
219         @Override
220         public double cumulativeProbability(double x) {
221             if (x <= 0) {
222                 return 0;
223             }
224             return Erf.value(x / sigmaSqrt2);
225         }
226 
227         @Override
228         public double survivalProbability(double x) {
229             if (x <= 0) {
230                 return 1;
231             }
232             return Erfc.value(x / sigmaSqrt2);
233         }
234 
235         @Override
236         public double inverseCumulativeProbability(double p) {
237             ArgumentUtils.checkProbability(p);
238             // Addition of 0.0 ensures 0.0 is returned for p=-0.0
239             return 0.0 + sigmaSqrt2 * InverseErf.value(p);
240         }
241 
242         /** {@inheritDoc} */
243         @Override
244         public double inverseSurvivalProbability(double p) {
245             ArgumentUtils.checkProbability(p);
246             return sigmaSqrt2 * InverseErfc.value(p);
247         }
248 
249         @Override
250         public double getMean() {
251             return sigma * Constants.ROOT_TWO_DIV_PI;
252         }
253 
254         @Override
255         public double getVariance() {
256             // sigma^2 - mean^2
257             // sigma^2 - (sigma^2 * 2/pi)
258             return sigma * sigma * VAR;
259         }
260 
261         @Override
262         public Sampler createSampler(UniformRandomProvider rng) {
263             // Return the absolute of a Gaussian distribution sampler.
264             final SharedStateContinuousSampler s = ZigguratSampler.NormalizedGaussian.of(rng);
265             return () -> Math.abs(s.sample() * sigma);
266         }
267     }
268 
269     /**
270      * @param sigma Scale parameter.
271      */
272     FoldedNormalDistribution(double sigma) {
273         this.sigma = sigma;
274         // Minimise rounding error by computing sqrt(2 * sigma * sigma) exactly.
275         // Compute using extended precision with care to avoid over/underflow.
276         sigmaSqrt2 = ExtendedPrecision.sqrt2xx(sigma);
277         // Compute sigma * sqrt(2 * pi)
278         sigmaSqrt2pi = ExtendedPrecision.xsqrt2pi(sigma);
279     }
280 
281     /**
282      * Creates a folded normal distribution. If the location {@code mu} is zero this is
283      * the half-normal distribution.
284      *
285      * @param mu Location parameter.
286      * @param sigma Scale parameter.
287      * @return the distribution
288      * @throws IllegalArgumentException if {@code sigma <= 0}.
289      */
290     public static FoldedNormalDistribution of(double mu,
291                                               double sigma) {
292         if (sigma > 0) {
293             if (mu == 0) {
294                 return new HalfNormalDistribution(sigma);
295             }
296             return new RegularFoldedNormalDistribution(mu, sigma);
297         }
298         // scale is zero, negative or nan
299         throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sigma);
300     }
301 
302     /**
303      * Gets the location parameter \( \mu \) of this distribution.
304      *
305      * @return the mu parameter.
306      */
307     public abstract double getMu();
308 
309     /**
310      * Gets the scale parameter \( \sigma \) of this distribution.
311      *
312      * @return the sigma parameter.
313      */
314     public double getSigma() {
315         return sigma;
316     }
317 
318     /**
319      * {@inheritDoc}
320      *
321      *
322      * <p>For location parameter \( \mu \) and scale parameter \( \sigma \), the mean is:
323      *
324      * <p>\[ \sigma \sqrt{ \frac 2 \pi } \exp \left( \frac{-\mu^2}{2\sigma^2} \right) +
325      *       \mu \operatorname{erf} \left( \frac \mu {\sqrt{2\sigma^2}} \right) \]
326      *
327      * <p>where \( \operatorname{erf} \) is the error function.
328      */
329     @Override
330     public abstract double getMean();
331 
332     /**
333      * {@inheritDoc}
334      *
335      * <p>For location parameter \( \mu \), scale parameter \( \sigma \) and a distribution
336      * mean \( \mu_Y \), the variance is:
337      *
338      * <p>\[ \mu^2 + \sigma^2 - \mu_{Y}^2 \]
339      */
340     @Override
341     public abstract double getVariance();
342 
343     /**
344      * {@inheritDoc}
345      *
346      * <p>The lower bound of the support is always 0.
347      *
348      * @return 0.
349      */
350     @Override
351     public double getSupportLowerBound() {
352         return 0.0;
353     }
354 
355     /**
356      * {@inheritDoc}
357      *
358      * <p>The upper bound of the support is always positive infinity.
359      *
360      * @return {@linkplain Double#POSITIVE_INFINITY positive infinity}.
361      */
362     @Override
363     public double getSupportUpperBound() {
364         return Double.POSITIVE_INFINITY;
365     }
366 }