1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.statistics.distribution;
19
20 import org.apache.commons.numbers.gamma.ErfDifference;
21 import org.apache.commons.numbers.gamma.Erfc;
22 import org.apache.commons.numbers.gamma.InverseErfc;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.sampling.distribution.GaussianSampler;
25 import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 public final class NormalDistribution extends AbstractContinuousDistribution {
42
43 private final double mean;
44
45 private final double standardDeviation;
46
47 private final double logStandardDeviationPlusHalfLog2Pi;
48
49
50
51
52
53
54
55
56
57
58
59 private final double sdSqrt2;
60
61
62
63 private final double sdSqrt2pi;
64
65
66
67
68
69 private NormalDistribution(double mean,
70 double sd) {
71 this.mean = mean;
72 standardDeviation = sd;
73 logStandardDeviationPlusHalfLog2Pi = Math.log(sd) + Constants.HALF_LOG_TWO_PI;
74
75
76 sdSqrt2 = ExtendedPrecision.sqrt2xx(sd);
77
78 sdSqrt2pi = ExtendedPrecision.xsqrt2pi(sd);
79 }
80
81
82
83
84
85
86
87
88
89 public static NormalDistribution of(double mean,
90 double sd) {
91 if (sd > 0) {
92 return new NormalDistribution(mean, sd);
93 }
94
95 throw new DistributionException(DistributionException.NOT_STRICTLY_POSITIVE, sd);
96 }
97
98
99
100
101
102
103 public double getStandardDeviation() {
104 return standardDeviation;
105 }
106
107
108 @Override
109 public double density(double x) {
110 final double z = (x - mean) / standardDeviation;
111 return ExtendedPrecision.expmhxx(z) / sdSqrt2pi;
112 }
113
114
115 @Override
116 public double probability(double x0,
117 double x1) {
118 if (x0 > x1) {
119 throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH,
120 x0, x1);
121 }
122 final double v0 = (x0 - mean) / sdSqrt2;
123 final double v1 = (x1 - mean) / sdSqrt2;
124 return 0.5 * ErfDifference.value(v0, v1);
125 }
126
127
128 @Override
129 public double logDensity(double x) {
130 final double z = (x - mean) / standardDeviation;
131 return -0.5 * z * z - logStandardDeviationPlusHalfLog2Pi;
132 }
133
134
135 @Override
136 public double cumulativeProbability(double x) {
137 final double dev = x - mean;
138 return 0.5 * Erfc.value(-dev / sdSqrt2);
139 }
140
141
142 @Override
143 public double survivalProbability(double x) {
144 final double dev = x - mean;
145 return 0.5 * Erfc.value(dev / sdSqrt2);
146 }
147
148
149 @Override
150 public double inverseCumulativeProbability(double p) {
151 ArgumentUtils.checkProbability(p);
152 return mean - sdSqrt2 * InverseErfc.value(2 * p);
153 }
154
155
156 @Override
157 public double inverseSurvivalProbability(double p) {
158 ArgumentUtils.checkProbability(p);
159 return mean + sdSqrt2 * InverseErfc.value(2 * p);
160 }
161
162
163 @Override
164 public double getMean() {
165 return mean;
166 }
167
168
169
170
171
172
173 @Override
174 public double getVariance() {
175 final double s = getStandardDeviation();
176 return s * s;
177 }
178
179
180
181
182
183
184
185
186 @Override
187 public double getSupportLowerBound() {
188 return Double.NEGATIVE_INFINITY;
189 }
190
191
192
193
194
195
196
197
198 @Override
199 public double getSupportUpperBound() {
200 return Double.POSITIVE_INFINITY;
201 }
202
203
204 @Override
205 public ContinuousDistribution.Sampler createSampler(final UniformRandomProvider rng) {
206
207 return GaussianSampler.of(ZigguratSampler.NormalizedGaussian.of(rng),
208 mean, standardDeviation)::sample;
209 }
210 }