1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.statistics.distribution;
19
20 import org.apache.commons.numbers.gamma.Erf;
21 import org.apache.commons.numbers.gamma.Erfcx;
22 import org.junit.jupiter.api.Assertions;
23 import org.junit.jupiter.api.Test;
24 import org.junit.jupiter.params.ParameterizedTest;
25 import org.junit.jupiter.params.provider.CsvSource;
26
27
28
29
30
31
32 class TruncatedNormalDistributionTest extends BaseContinuousDistributionTest {
33 @Override
34 ContinuousDistribution makeDistribution(Object... parameters) {
35 final double mean = (Double) parameters[0];
36 final double sd = (Double) parameters[1];
37 final double upper = (Double) parameters[2];
38 final double lower = (Double) parameters[3];
39 return TruncatedNormalDistribution.of(mean, sd, upper, lower);
40 }
41
42 @Override
43 Object[][] makeInvalidParameters() {
44 return new Object[][] {
45 {0.0, 0.0, -1.0, 1.0},
46 {0.0, -0.1, -1.0, 1.0},
47 {0.0, 1.0, 1.0, -1.0},
48
49 {0.0, 1.0, 100.0, 101.0},
50 };
51 }
52
53 @Override
54 String[] getParameterNames() {
55
56
57 return new String[] {null, null, "SupportLowerBound", "SupportUpperBound"};
58 }
59
60 @Override
61 protected double getRelativeTolerance() {
62 return 1e-14;
63 }
64
65
66
67
68
69
70
71
72
73
74
75
76
77 @ParameterizedTest
78 @CsvSource({
79 "0.0, 1.0, -4, 6",
80 "1.0, 2.0, -4, 6",
81 "3.45, 6.78, -8, 10",
82 })
83 void testMomentsEffectivelyNoTruncation(double mean, double sd, double lower, double upper) {
84 double inf = Double.POSITIVE_INFINITY;
85 double max = Double.MAX_VALUE;
86 TruncatedNormalDistribution dist1;
87 TruncatedNormalDistribution dist2;
88
89 dist1 = TruncatedNormalDistribution.of(mean, sd, -inf, upper);
90 dist2 = TruncatedNormalDistribution.of(mean, sd, -max, upper);
91 Assertions.assertEquals(dist1.getMean(), dist2.getMean(), "Mean");
92 Assertions.assertEquals(dist1.getVariance(), dist2.getVariance(), "Variance");
93
94 dist1 = TruncatedNormalDistribution.of(mean, sd, lower, inf);
95 dist2 = TruncatedNormalDistribution.of(mean, sd, lower, max);
96 Assertions.assertEquals(dist1.getMean(), dist2.getMean(), "Mean");
97 Assertions.assertEquals(dist1.getVariance(), dist2.getVariance(), "Variance");
98
99 dist1 = TruncatedNormalDistribution.of(mean, sd, -inf, inf);
100 dist2 = TruncatedNormalDistribution.of(mean, sd, -max, max);
101 Assertions.assertEquals(dist1.getMean(), dist2.getMean(), "Mean");
102 Assertions.assertEquals(dist1.getVariance(), dist2.getVariance(), "Variance");
103 }
104
105
106
107
108
109
110
111
112 @Test
113 void testMean() {
114 assertMean(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0, 0);
115 assertMean(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0);
116 assertMean(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY, 0);
117 assertMean(0, Double.POSITIVE_INFINITY, Math.sqrt(2 / Math.PI), 1e-15);
118 assertMean(Double.NEGATIVE_INFINITY, 0, -Math.sqrt(2 / Math.PI), 1e-15);
119
120 for (int x = -10; x <= 10; x++) {
121 final double expected = Math.sqrt(2 / Math.PI) / Erfcx.value(x / Math.sqrt(2));
122 assertMean(x, Double.POSITIVE_INFINITY, expected, 1e-15);
123 }
124
125 for (int i = -100; i <= 100; i++) {
126 final double x = Math.exp(i);
127 assertMean(-x, x, 0, 0);
128 final double expected = -Math.sqrt(2 / Math.PI) * Math.expm1(-x * x / 2) / Erf.value(x / Math.sqrt(2));
129 assertMean(0, x, expected, 1e-15);
130 }
131
132 assertMean(1e-44, 1e-43, 5.4999999999999999999999999999999999999999e-44, 1e-15);
133
134 assertMean(100, 115, 100.00999800099926070518490239457545847490332879043, 1e-15);
135 assertMean(-1e6, -999000, -999000.00000100100100099899498898098, 1e-15);
136 assertMean(+1e6, Double.POSITIVE_INFINITY, +1.00000000000099999999999800000e6, 1e-15);
137 assertMean(Double.NEGATIVE_INFINITY, -1e6, -1.00000000000099999999999800000e6, 1e-15);
138
139 assertMean(-1e200, 1e200, 0, 1e-15);
140 assertMean(0, +1e200, +0.797884560802865355879892119869, 1e-15);
141 assertMean(-1e200, 0, -0.797884560802865355879892119869, 1e-15);
142
143 assertMean(50, 70, -2, 3, 50.171943499898757645751683644632860837133138152489, 1e-15);
144 assertMean(-100.0, 0.0, 0.0, 2.0986317998643735, -1.6744659119217125058885983754999713622460154892645, 1e-15);
145 assertMean(0.0, 0.9, 0.0, 0.07132755843183151, 0.056911157632522598806524588414964004271754161737065, 1e-15);
146 assertMean(-100.0, 100.0, 0.0, 17.185261847875548, 0, 1e-15);
147 assertMean(-100.0, 0.5, 0.0, 0.47383322897860064, -0.1267981330521791493635176736743283314399, 1e-15);
148 assertMean(-100.0, 100.0, 0.0, 17.185261847875548, 0, 1e-15);
149
150 for (int i = -10; i <= 10; i++) {
151 final double a = Math.exp(i);
152 for (int j = -10; j <= 10; j++) {
153 final double b = Math.exp(j);
154 if (a <= b) {
155 final double mean = TruncatedNormalDistribution.moment1(a, b);
156 Assertions.assertTrue(a <= mean && mean <= b);
157 }
158 }
159 }
160
161
162 assertMean(0, 1000, 1000000, 1, 999.99999899899899900100501101901899090472046236710608108591983, 6e-14);
163 }
164
165
166
167
168
169
170
171
172 @Test
173 void testVariance() {
174 assertVariance(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1, 0);
175 assertVariance(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0, 0);
176 assertVariance(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY, 0, 0);
177 assertVariance(0, Double.POSITIVE_INFINITY, 1 - 2 / Math.PI, 1e-15);
178 assertVariance(Double.NEGATIVE_INFINITY, 0, 1 - 2 / Math.PI, 1e-15);
179
180 for (int x = -10; x <= 10; x++) {
181 final double expected = 1 + Math.sqrt(2 / Math.PI) * x / Erfcx.value(x / Math.sqrt(2)) -
182 (2 / Math.PI) / Math.pow(Erfcx.value(x / Math.sqrt(2)), 2);
183 assertVariance(x, Double.POSITIVE_INFINITY, expected, 1e-11);
184 }
185
186 assertVariance(50, 70, 0.0003990431868038995479099272265360593305365, 1e-9);
187
188 assertVariance(50, 70, -2, 3, 0.029373438107168350377591231295634273607812172191712, 1e-11);
189 assertVariance(-100.0, 0.0, 0.0, 2.0986317998643735, 1.6004193412141677189841357987638847137391508803335, 1e-15);
190 assertVariance(0.0, 0.9, 0.0, 0.07132755843183151, 0.0018487407287725028827020557707636415445504260892486, 1e-15);
191 assertVariance(-100.0, 100.0, 0.0, 17.185261847875548, 295.333163899557735486302841237124507431445, 1e-15);
192 assertVariance(-100.0, 0.5, 0.0, 0.47383322897860064, 0.145041095812679283837328561547251019229612, 1e-15);
193 assertVariance(-100.0, 100.0, 0.0, 17.185261847875548, 295.333163899557735486302841237124507431445, 1e-15);
194 assertVariance(-10000, 10000, 0, 1, 1, 1e-15);
195
196
197 Assertions.assertTrue(TruncatedNormalDistribution.variance(999000, 1e6) >= 0);
198 Assertions.assertTrue(TruncatedNormalDistribution.variance(-1000000, 1000 - 1000000) >= 0);
199
200
201
202
203
204
205 }
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240 @ParameterizedTest
241 @CsvSource({
242
243 "1.23, 1.23, 1.23, 0, 0, 0",
244 "1.23, 4.56, 1.7122093853640246, 0.1739856461219162, 1e-15, 5e-15",
245
246
247 "-55, 60, 0, 1, 0, 0",
248
249
250 "-100, 101, 1.3443134677817230433408433600205167e-2172, 1, 1e-15, 1e-15",
251 "-40, 101, 1.46327025083830317873709720033828097e-348, 1, 1e-15, 1e-15",
252 "-30, 101, 1.47364613487854751904949326604507453e-196, 1, 1e-15, 1e-15",
253 "-20, 101, 5.52094836215976318958273568278700042e-88, 1, 1e-15, 1e-15",
254 "-10, 101, 7.69459862670641934633909221175249367e-23, 0.999999999999999999999230540137329438, 1e-15, 1e-15",
255 "-5, 101, 1.48671994090490571244174411946057083e-06, 0.999992566398085139288753504945569711, 1e-15, 1e-15",
256 "-1, 101, 0.287599970939178361228670127385217202, 0.629686285776605400861244494862843017, 1e-15, 1e-15",
257 "0, 101, 0.797884560802865355879892119868763748, 0.363380227632418656924464946509942526, 1e-15, 1e-15",
258 "1, 101, 1.52513527616098120908909053639057876, 0.199097665570348791553367979096726767, 1e-15, 1e-14",
259 "5, 101, 5.18650396712584211561650896200523673, 0.032696434617112225345315807700917674, 1e-15, 1e-13",
260 "10, 101, 10.0980932339625119628436416537120371, 0.00944537782565626116413681765035684208, 1e-15, 1e-11",
261 "20, 101, 20.0497530685278505422140233087209891, 0.00246326161505216359968528619980015911, 1e-15, 1e-11",
262 "30, 101, 30.033259667433677037071124100012257, 0.00110377151189009100113674138540728116, 1e-15, 1e-10",
263 "40, 101, 40.0249688472072637232448709953697417, 0.000622668378591388773498879400697584317, 1e-15, 2e-9",
264 "100, 101, 100.009998000999260705184902394575471, 9.99400499482634503612772420030347819e-05, 1e-15, 2e-8",
265
266
267 "-5, Infinity, 1.4867199409049057124417441194605712e-06, 0.999992566398085139288753504945569711, 1e-14, 1e-14",
268 "-3, Infinity, 0.00443783904212566379330210431090259846, 0.98666678845825919379095350748267984, 1e-15, 1e-15",
269 "-1, Infinity, 0.287599970939178361228670127385217154, 0.629686285776605400861244494862843306, 1e-15, 1e-15",
270 "0, Infinity, 0.797884560802865355879892119868763748, 0.363380227632418656924464946509942526, 1e-15, 1e-15",
271 "1, Infinity, 1.52513527616098120908909053639057876, 0.199097665570348791553367979096726767, 1e-15, 1e-15",
272 "3, Infinity, 3.28309865493043650692809222681220005, 0.0705591867852681168624020577420568271, 1e-15, 2e-14",
273 "20, Infinity, 20.0497530685278505422140233087209891, 0.00246326161505216359968528619980015911, 1e-15, 1e-11",
274 "100, Infinity, 100.009998000999260705184902394575471, 9.99400499482634503612772420030347819e-05, 1e-15, 4e-8",
275
276 "1e4, Infinity, 10000.0000999999980000000999999925986, 9.99999940000005002391967510312099493e-09, 1e-15, 0.8",
277 "1e6, Infinity, 1000000.00000099999999999800000000016, 9.99999999770471649802883928921316157e-13, 1e-15, 1.0",
278
279
280
281 "1e100, Infinity, 1.00000000000000001590289110975991788e+100, 0, 1e-15, -1",
282
283
284
285 "1e290, 1e300, 1.00000000000000006172783352786715689e+290, 0, 1e-15, -1",
286
287
288 "1, 1.1000000000000001, 1.04912545221799091312759556239135752, 0.000832596851563726615564931035799390151, 1e-15, 2e-12",
289 "5, 5.0999999999999996, 5.04581083165668427678725919870992629, 0.000822546087919772895415146023240560636, 1e-15, 2e-11",
290 "35, 35.100000000000001, 35.025438801080858717764612789648226, 0.000494605845872597846399929727938197022, 1e-15, 2e-9",
291
292
293
294
295
296
297
298 "1, 1.0000000000000002, 1.00000000000000011091535982917837267, 0, 1e-15, -1",
299
300 "4, 4.0000000000000009, 4.00000000000000044406536771487238653, 0, 1e-15, -1",
301
302 "10, 10.000000000000002, 10.0000000000000008883225369216741152, 0, 1e-15, -1",
303
304
305
306
307
308
309
310
311
312
313 "-7.299454196351098e-8, 7.299454196351098e-8, 0, 1.77606771882092042827020676955306864e-15, 1e-15, -1e-15",
314 "-7.299454196351098e-8, 3.649727098175549e-8, -1.82486354908777262111748030604612676e-08, 9.99038091836768051420202283759953002e-16, 1e-15, -1e-15",
315 "-7.299454196351098e-8, 1.8248635490877744e-8, -2.7372953236316597672674778496667655e-08, 6.93776452664422342699175710737901419e-16, 1e-15, -2e-15",
316 "-7.299454196351098e-8, 0, -3.64972709817554726791073610445021429e-08, 4.44016929705230343732389204118195096e-16, 1e-15, -2e-15",
317 "-7.299454196351098e-8, -1.8248635490877744e-8, -4.56215887271943497112157190855901547e-08, 2.49759522957641055973442997155578316e-16, 3e-10, -5e-9",
318 "-7.299454196351098e-8, -3.649727098175549e-8, -5.47459064726332272497430210977513379e-08, 1.11004232421306844799494326433537718e-16, 3e-10, -2e-8",
319 "-3.649727098175549e-8, 3.649727098175549e-8, 0, 4.44016929705230343602092590994317462e-16, 1e-15, -1e-15",
320 "-3.649727098175549e-8, 1.8248635490877744e-8, -9.12431774543886994224314381693928319e-09, 2.49759522959192087703300220816741702e-16, 1e-15, -1e-15",
321 "-3.649727098175549e-8, 0, -1.82486354908777424165810069993271136e-08, 1.11004232426307600672725101668733634e-16, 1e-15, -2e-15",
322 "-3.649727098175549e-8, -1.8248635490877744e-8, -2.73729532363166159037567578213044937e-08, 2.77510581119222125912321725734620803e-17, 3e-10, -2e-8",
323 "-1.8248635490877744e-8, 1.8248635490877744e-8, 0, 1.11004232426307600757649604002272128e-16, 1e-15, -1e-15",
324 "-1.8248635490877744e-8, 9.124317745438872e-9, -4.5621588727194358257035396943085424e-09, 6.24398807397980267185125584627689296e-17, 1e-15, -1e-15",
325 "-1.8248635490877744e-8, 0, -9.12431774543887196791891930929818729e-09, 2.77510581065769011631256630479419225e-17, 1e-15, -1e-15",
326 "-9.124317745438872e-9, 9.124317745438872e-9, 0, 2.77510581065769013145632047586655539e-17, 1e-15, -1e-15",
327 "-9.124317745438872e-9, 4.562158872719436e-9, -2.28107943635971801967451582038414367e-09, 1.56099701849495071020338587207547036e-17, 1e-15, -1e-15",
328 "-9.124317745438872e-9, 0, -4.5621588727194360789130116308534264e-09, 6.93776452664422554954584023114952882e-18, 1e-15, -1e-15",
329
330
331
332 "0, 2.220446049250313e-16, 1.11022302462515654042363166809081572e-16, 4.14074938043255708407035257655783112e-33, 1e-15, -1e-2",
333 })
334 void testAdditionalMoments(double lower, double upper,
335 double mean, double variance,
336 double meanRelativeError, double varianceRelativeError) {
337 assertMean(lower, upper, mean, meanRelativeError);
338 if (varianceRelativeError < 0) {
339
340
341
342
343 final double var = TruncatedNormalDistribution.variance(lower, upper);
344 Assertions.assertTrue(var >= 0, () -> "Variance is not positive: " + var);
345 Assertions.assertEquals(var, TruncatedNormalDistribution.variance(-upper, -lower));
346 TestUtils.assertEquals(variance, var,
347 createAbsOrRelTolerance(1.5 * 0x1.0p-52, -varianceRelativeError),
348 () -> String.format("variance(%s, %s)", lower, upper));
349 } else {
350 assertVariance(lower, upper, variance, varianceRelativeError);
351 }
352 }
353
354
355
356
357 private static void assertMean(double lower, double upper, double expected, double eps) {
358 final double mean = TruncatedNormalDistribution.moment1(lower, upper);
359 Assertions.assertEquals(0 - mean, TruncatedNormalDistribution.moment1(-upper, -lower));
360 TestUtils.assertEquals(expected, mean, DoubleTolerances.relative(eps),
361 () -> String.format("mean(%s, %s)", lower, upper));
362 }
363
364
365
366
367
368
369 private static void assertMean(double lower, double upper, double u, double s, double expected, double eps) {
370 final double a = (lower - u) / s;
371 final double b = (upper - u) / s;
372 final double mean = u + TruncatedNormalDistribution.moment1(a, b) * s;
373 TestUtils.assertEquals(expected, mean, DoubleTolerances.relative(eps),
374 () -> String.format("mean(%s, %s, %s, %s)", lower, upper, u, s));
375 }
376
377
378
379
380 private static void assertVariance(double lower, double upper, double expected, double eps) {
381 final double var = TruncatedNormalDistribution.variance(lower, upper);
382 Assertions.assertEquals(var, TruncatedNormalDistribution.variance(-upper, -lower));
383 TestUtils.assertEquals(expected, var, DoubleTolerances.relative(eps),
384 () -> String.format("variance(%s, %s)", lower, upper));
385 }
386
387
388
389
390
391
392 private static void assertVariance(double lower, double upper, double u, double s, double expected, double eps) {
393 final double a = (lower - u) / s;
394 final double b = (upper - u) / s;
395 final double var = TruncatedNormalDistribution.variance(a, b) * s * s;
396 TestUtils.assertEquals(expected, var, DoubleTolerances.relative(eps),
397 () -> String.format("variance(%s, %s, %s, %s)", lower, upper, u, s));
398 }
399 }