1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.statistics.examples.distribution;
18
19 import java.io.BufferedReader;
20 import java.io.File;
21 import java.io.IOException;
22 import java.io.PrintWriter;
23 import java.io.UncheckedIOException;
24 import java.nio.file.Files;
25 import java.util.Arrays;
26 import java.util.List;
27 import org.apache.commons.numbers.core.Precision;
28 import org.apache.commons.statistics.distribution.ContinuousDistribution;
29 import org.apache.commons.statistics.distribution.DiscreteDistribution;
30
31
32
33
34 final class DistributionUtils {
35
36 private static final String UNKNOWN_FUNCTION = "Unknown function: ";
37
38 private static final double MAX_RELATIVE_ERROR = 1e-6;
39
40 private static final double DELTA_P = 1e-6;
41
42
43 private DistributionUtils() {}
44
45
46
47
48 interface ContinuousFunction {
49
50
51
52
53
54
55
56 double apply(ContinuousDistribution dist, double x);
57 }
58
59
60
61
62 interface DiscreteFunction {
63
64
65
66
67
68
69
70 double apply(DiscreteDistribution dist, int x);
71 }
72
73
74
75
76 interface InverseDiscreteFunction {
77
78
79
80
81
82
83
84 int apply(DiscreteDistribution dist, double x);
85 }
86
87
88
89
90
91
92
93 static void evaluate(List<Distribution<ContinuousDistribution>> dist,
94 ContinuousDistributionOptions distributionOptions) {
95 try (PrintWriter out = createOutput(distributionOptions)) {
96 final ContinuousFunction fun = createFunction(distributionOptions);
97 final double[] points = createPoints(distributionOptions);
98
99 final String delim = createDelimiter(distributionOptions);
100 createHeader("x", dist, out, delim);
101
102
103 final String format = distributionOptions.format;
104 final String xformat = distributionOptions.xformat;
105 for (final double x : points) {
106 out.format(xformat, x);
107 dist.forEach(d -> {
108 out.print(delim);
109 out.format(format, fun.apply(d.getDistribution(), x));
110 });
111 out.println();
112 }
113 }
114 }
115
116
117
118
119
120
121
122 static void evaluate(List<Distribution<ContinuousDistribution>> dist,
123 InverseContinuousDistributionOptions distributionOptions) {
124 try (PrintWriter out = createOutput(distributionOptions)) {
125 final ContinuousFunction fun = createFunction(distributionOptions);
126 final double[] points = createPoints(distributionOptions);
127
128 final String delim = createDelimiter(distributionOptions);
129 createHeader("p", dist, out, delim);
130
131
132 final String format = distributionOptions.format;
133 final String xformat = distributionOptions.pformat;
134 for (final double p : points) {
135 out.format(xformat, p);
136 dist.forEach(d -> {
137 out.print(delim);
138 out.format(format, fun.apply(d.getDistribution(), p));
139 });
140 out.println();
141 }
142 }
143 }
144
145
146
147
148
149
150
151 static void evaluate(List<Distribution<DiscreteDistribution>> dist,
152 DiscreteDistributionOptions distributionOptions) {
153 try (PrintWriter out = createOutput(distributionOptions)) {
154 final DiscreteFunction fun = createFunction(distributionOptions);
155 final int[] points = createPoints(distributionOptions);
156
157 final String delim = createDelimiter(distributionOptions);
158 createHeader("x", dist, out, delim);
159
160
161 final String format = distributionOptions.format;
162 for (final int x : points) {
163 out.print(x);
164 dist.forEach(d -> {
165 out.print(delim);
166 out.format(format, fun.apply(d.getDistribution(), x));
167 });
168 out.println();
169 }
170 }
171 }
172
173
174
175
176
177
178
179 static void evaluate(List<Distribution<DiscreteDistribution>> dist,
180 InverseDiscreteDistributionOptions distributionOptions) {
181 try (PrintWriter out = createOutput(distributionOptions)) {
182 final InverseDiscreteFunction fun = createFunction(distributionOptions);
183 final double[] points = createPoints(distributionOptions);
184
185 final String delim = createDelimiter(distributionOptions);
186 createHeader("p", dist, out, delim);
187
188
189 final String format = distributionOptions.pformat;
190 for (final double p : points) {
191 out.format(format, p);
192 dist.forEach(d -> {
193 out.print(delim);
194 out.print(fun.apply(d.getDistribution(), p));
195 });
196 out.println();
197 }
198 }
199 }
200
201
202
203
204
205
206
207 static void check(List<Distribution<ContinuousDistribution>> dist,
208 ContinuousDistributionOptions distributionOptions) {
209 try (PrintWriter out = createOutput(distributionOptions)) {
210 final double[] points = createPoints(distributionOptions);
211
212 dist.forEach(d -> {
213 final ContinuousDistribution dd = d.getDistribution();
214 final String title = dd.getClass().getSimpleName() + " " + d.getParameters();
215
216
217 final double lower = dd.getSupportLowerBound();
218 final double upper = dd.getSupportUpperBound();
219 if (!(lower == dd.inverseCumulativeProbability(0))) {
220 out.printf("%s lower icdf(0.0) : %s != %s", title, lower, dd.inverseCumulativeProbability(0));
221 }
222 if (!(upper == dd.inverseCumulativeProbability(1))) {
223 out.printf("%s upper icdf(1.0) : %s != %s", title, upper, dd.inverseCumulativeProbability(1));
224 }
225 if (!(lower == dd.inverseSurvivalProbability(1))) {
226 out.printf("%s lower isf(1.0) : %s != %s", title, lower, dd.inverseSurvivalProbability(1));
227 }
228 if (!(upper == dd.inverseSurvivalProbability(0))) {
229 out.printf("%s upper isf(0.0) : %s != %s", title, upper, dd.inverseSurvivalProbability(0));
230 }
231
232 for (final double x : points) {
233 final double p1 = dd.cumulativeProbability(x);
234 final double p2 = dd.survivalProbability(x);
235 final double s = p1 + p2;
236 if (!(Math.abs(1.0 - s) < 1e-10)) {
237 out.printf("%s x=%s : cdf + survival != 1.0 : %s + %s%n", title, x, p1, p2);
238 }
239
240 if (!closeToInteger(p1)) {
241 final double xx = dd.inverseCumulativeProbability(p1);
242 if (!Precision.equalsWithRelativeTolerance(x, xx, MAX_RELATIVE_ERROR) &&
243
244 !Precision.equalsWithRelativeTolerance(p1, dd.cumulativeProbability(xx),
245 MAX_RELATIVE_ERROR)) {
246 out.printf("%s x=%s : icdf(%s) : %s (cdf=%s)%n", title, x, p1, xx,
247 dd.cumulativeProbability(xx));
248 }
249 }
250
251 if (!closeToInteger(p2)) {
252 final double xx = dd.inverseSurvivalProbability(p2);
253 if (!Precision.equalsWithRelativeTolerance(x, xx, MAX_RELATIVE_ERROR) &&
254
255 !Precision.equalsWithRelativeTolerance(p2, dd.survivalProbability(xx),
256 MAX_RELATIVE_ERROR)) {
257 out.printf("%s x=%s : isf(%s) : %s (sf=%s)%n", title, x, p2, xx,
258 dd.survivalProbability(xx));
259 }
260 }
261 }
262
263 for (final double x : points) {
264 final double p1 = dd.density(x);
265 final double lp = dd.logDensity(x);
266 final double p2 = Math.exp(lp);
267 if (!Precision.equalsWithRelativeTolerance(p1, p2, MAX_RELATIVE_ERROR)) {
268 out.printf("%s x=%s : pdf != exp(logpdf) : %s != %s%n", title, x, p1, p2);
269 }
270 }
271 });
272 }
273 }
274
275
276
277
278
279
280
281 static void check(List<Distribution<DiscreteDistribution>> dist,
282 DiscreteDistributionOptions distributionOptions) {
283 try (PrintWriter out = createOutput(distributionOptions)) {
284 final int[] points = createPoints(distributionOptions);
285
286 dist.forEach(d -> {
287 final DiscreteDistribution dd = d.getDistribution();
288 final String title = dd.getClass().getSimpleName() + " " + d.getParameters();
289
290
291 final int lower = dd.getSupportLowerBound();
292 final int upper = dd.getSupportUpperBound();
293 if (!(lower == dd.inverseCumulativeProbability(0))) {
294 out.printf("%s lower != icdf(0.0) : %d != %d", title, lower, dd.inverseCumulativeProbability(0));
295 }
296 if (!(upper == dd.inverseCumulativeProbability(1))) {
297 out.printf("%s upper != icdf(1.0) : %d != %d", title, upper, dd.inverseCumulativeProbability(1));
298 }
299 if (!(lower == dd.inverseSurvivalProbability(1))) {
300 out.printf("%s lower isf(1.0) : %d != %d", title, lower, dd.inverseSurvivalProbability(1));
301 }
302 if (!(upper == dd.inverseSurvivalProbability(0))) {
303 out.printf("%s upper isf(0.0) : %d != %d", title, upper, dd.inverseSurvivalProbability(0));
304 }
305
306 for (final int x : points) {
307 final double p1 = dd.cumulativeProbability(x);
308 final double p2 = dd.survivalProbability(x);
309 final double s = p1 + p2;
310 if (!(Math.abs(1.0 - s) < 1e-10)) {
311 out.printf("%s x=%d : cdf + survival != 1.0 : %s + %s%n", title, x, p1, p2);
312 }
313
314 if (!closeToInteger(p1)) {
315 final int xx = dd.inverseCumulativeProbability(p1);
316 if (x != xx) {
317 out.printf("%s x=%d : icdf(%s) : %d (cdf=%s)%n", title, x, p1, xx,
318 dd.cumulativeProbability(xx));
319 }
320 }
321
322 if (!closeToInteger(p2)) {
323 final int xx = dd.inverseSurvivalProbability(p2);
324 if (x != xx) {
325 out.printf("%s x=%d : isf(%s) : %d (sf=%s)%n", title, x, p2, xx,
326 dd.survivalProbability(xx));
327 }
328 }
329 }
330
331 for (final int x : points) {
332 final double p1 = dd.probability(x);
333 final double lp = dd.logProbability(x);
334 final double p2 = Math.exp(lp);
335 if (!Precision.equalsWithRelativeTolerance(p1, p2, MAX_RELATIVE_ERROR)) {
336 out.printf("%s x=%d : pmf != exp(logpmf) : %s != %s%n", title, x, p1, p2);
337 }
338 }
339 });
340 }
341 }
342
343
344
345
346
347
348
349 private static PrintWriter createOutput(DistributionOptions distributionOptions) {
350 if (distributionOptions.outputFile != null) {
351 try {
352 return new PrintWriter(Files.newBufferedWriter(distributionOptions.outputFile.toPath()));
353 } catch (IOException ex) {
354 throw new UncheckedIOException("Failed to create output: " + distributionOptions.outputFile, ex);
355 }
356 }
357 return new PrintWriter(System.out) {
358 @Override
359 public void close() {
360
361 flush();
362 }
363 };
364 }
365
366
367
368
369
370
371
372 private static String createDelimiter(DistributionOptions distributionOptions) {
373 final String delim = distributionOptions.delim;
374
375 return delim.replace("\\t", "\t");
376 }
377
378
379
380
381
382
383
384
385
386
387 private static <T> void createHeader(String xname, List<Distribution<T>> dist, final PrintWriter out,
388 final String delim) {
389
390 out.print(xname);
391 dist.forEach(d -> {
392 out.print(delim);
393 out.print(d.getParameters());
394 });
395 out.println();
396 }
397
398
399
400
401
402
403
404 private static ContinuousFunction createFunction(ContinuousDistributionOptions distributionOptions) {
405 ContinuousFunction f;
406 switch (distributionOptions.distributionFunction) {
407 case PDF:
408 f = ContinuousDistribution::density;
409 break;
410 case LPDF:
411 f = ContinuousDistribution::logDensity;
412 break;
413 case CDF:
414 f = ContinuousDistribution::cumulativeProbability;
415 break;
416 case SF:
417 f = ContinuousDistribution::survivalProbability;
418 break;
419 default:
420 throw new IllegalArgumentException(UNKNOWN_FUNCTION + distributionOptions.distributionFunction);
421 }
422 if (!distributionOptions.suppressException) {
423 return f;
424 }
425 return new ContinuousFunction() {
426 @Override
427 public double apply(ContinuousDistribution dist, double x) {
428 try {
429 return f.apply(dist, x);
430 } catch (IllegalArgumentException ex) {
431
432 return Double.NaN;
433 }
434 }
435 };
436 }
437
438
439
440
441
442
443
444 private static DiscreteFunction createFunction(DiscreteDistributionOptions distributionOptions) {
445 DiscreteFunction f;
446 switch (distributionOptions.distributionFunction) {
447 case PMF:
448 f = DiscreteDistribution::probability;
449 break;
450 case LPMF:
451 f = DiscreteDistribution::logProbability;
452 break;
453 case CDF:
454 f = DiscreteDistribution::cumulativeProbability;
455 break;
456 case SF:
457 f = DiscreteDistribution::survivalProbability;
458 break;
459 default:
460 throw new IllegalArgumentException(UNKNOWN_FUNCTION + distributionOptions.distributionFunction);
461 }
462 if (!distributionOptions.suppressException) {
463 return f;
464 }
465 return new DiscreteFunction() {
466 @Override
467 public double apply(DiscreteDistribution dist, int x) {
468 try {
469 return f.apply(dist, x);
470 } catch (IllegalArgumentException ex) {
471
472 return Double.NaN;
473 }
474 }
475 };
476 }
477
478
479
480
481
482
483
484 private static ContinuousFunction createFunction(InverseContinuousDistributionOptions distributionOptions) {
485 ContinuousFunction f;
486 switch (distributionOptions.distributionFunction) {
487 case ICDF:
488 f = ContinuousDistribution::inverseCumulativeProbability;
489 break;
490 case ISF:
491 f = ContinuousDistribution::inverseSurvivalProbability;
492 break;
493 default:
494 throw new IllegalArgumentException(UNKNOWN_FUNCTION + distributionOptions.distributionFunction);
495 }
496 if (!distributionOptions.suppressException) {
497 return f;
498 }
499 return new ContinuousFunction() {
500 @Override
501 public double apply(ContinuousDistribution dist, double x) {
502 try {
503 return f.apply(dist, x);
504 } catch (IllegalArgumentException ex) {
505
506 return Double.NaN;
507 }
508 }
509 };
510 }
511
512
513
514
515
516
517
518 private static InverseDiscreteFunction createFunction(InverseDiscreteDistributionOptions distributionOptions) {
519 InverseDiscreteFunction f;
520 switch (distributionOptions.distributionFunction) {
521 case ICDF:
522 f = DiscreteDistribution::inverseCumulativeProbability;
523 break;
524 case ISF:
525 f = DiscreteDistribution::inverseSurvivalProbability;
526 break;
527 default:
528 throw new IllegalArgumentException(UNKNOWN_FUNCTION + distributionOptions.distributionFunction);
529 }
530 if (!distributionOptions.suppressException) {
531 return f;
532 }
533 return new InverseDiscreteFunction() {
534 @Override
535 public int apply(DiscreteDistribution dist, double x) {
536 try {
537 return f.apply(dist, x);
538 } catch (IllegalArgumentException ex) {
539
540 return Integer.MIN_VALUE;
541 }
542 }
543 };
544 }
545
546
547
548
549
550
551
552 private static double[] createPoints(ContinuousDistributionOptions distributionOptions) {
553 if (distributionOptions.x != null) {
554 return distributionOptions.x;
555 }
556 if (distributionOptions.inputFile != null) {
557 return readDoublePoints(distributionOptions.inputFile);
558 }
559 return enumerate(distributionOptions.min, distributionOptions.max,
560 distributionOptions.steps);
561 }
562
563
564
565
566
567
568
569 private static int[] createPoints(DiscreteDistributionOptions distributionOptions) {
570 if (distributionOptions.x != null) {
571 return distributionOptions.x;
572 }
573 if (distributionOptions.inputFile != null) {
574 return readIntPoints(distributionOptions.inputFile);
575 }
576 return series(distributionOptions.min, distributionOptions.max,
577 distributionOptions.increment);
578 }
579
580
581
582
583
584
585
586 private static double[] createPoints(InverseDiscreteDistributionOptions distributionOptions) {
587 if (distributionOptions.x != null) {
588 return distributionOptions.x;
589 }
590 if (distributionOptions.inputFile != null) {
591 return readDoublePoints(distributionOptions.inputFile);
592 }
593 return enumerate(distributionOptions.min, distributionOptions.max,
594 distributionOptions.steps);
595 }
596
597
598
599
600
601
602
603 private static double[] readDoublePoints(File inputFile) {
604 double[] points = new double[10];
605 int size = 0;
606 try (BufferedReader in = Files.newBufferedReader(inputFile.toPath())) {
607 for (String line = in.readLine(); line != null; line = in.readLine()) {
608 final double x = Double.parseDouble(line);
609 if (points.length == size) {
610 points = Arrays.copyOf(points, size * 2);
611 }
612 points[size++] = x;
613 }
614 } catch (final IOException e) {
615 throw new UncheckedIOException(e);
616 } catch (final NumberFormatException e) {
617 throw new RuntimeException("Input file should contain a real number on each line", e);
618 }
619 return Arrays.copyOf(points, size);
620 }
621
622
623
624
625
626
627
628 private static int[] readIntPoints(File inputFile) {
629 int[] points = new int[10];
630 int size = 0;
631 try (BufferedReader in = Files.newBufferedReader(inputFile.toPath())) {
632 for (String line = in.readLine(); line != null; line = in.readLine()) {
633 final int x = Integer.parseInt(line);
634 if (points.length == size) {
635 points = Arrays.copyOf(points, size * 2);
636 }
637 points[size++] = x;
638 }
639 } catch (final IOException e) {
640 throw new UncheckedIOException(e);
641 } catch (final NumberFormatException e) {
642 throw new RuntimeException("Input file should contain an integer on each line", e);
643 }
644 return Arrays.copyOf(points, size);
645 }
646
647
648
649
650
651
652
653
654
655 private static double[] enumerate(double min, double max, int steps) {
656 if (!Double.isFinite(min)) {
657 throw new IllegalArgumentException("Invalid minimum: " + min);
658 }
659 if (!Double.isFinite(max)) {
660 throw new IllegalArgumentException("Invalid maximum: " + max);
661 }
662 if (min == max) {
663 return new double[] {min};
664 }
665 final double[] x = new double[steps + 1];
666 final double dx = (max - min) / steps;
667 for (int i = 0; i < steps; i++) {
668 x[i] = min + i * dx;
669 }
670 x[steps] = max;
671 return x;
672 }
673
674
675
676
677
678
679
680
681
682 private static int[] series(int min, int max, int increment) {
683 if (min == max) {
684 return new int[] {min};
685 }
686 final int steps = (int) Math.ceil((double) (max - min) / increment);
687 final int[] x = new int[steps + 1];
688 for (int i = 0; i < steps; i++) {
689 x[i] = min + i * increment;
690 }
691 x[steps] = max;
692 return x;
693 }
694
695
696
697
698
699
700
701
702
703
704
705 static int validateLengths(int... lengths) {
706 int max = 0;
707 for (final int l : lengths) {
708 max = max < l ? l : max;
709 }
710
711 for (final int l : lengths) {
712 if (l != 1 && l != max) {
713 throw new IllegalArgumentException(
714 "Invalid parameter array length: " + l +
715 ". Lengths must by either 1 or the maximum (" + max + ").");
716 }
717 }
718 return max;
719 }
720
721
722
723
724
725
726
727
728
729
730
731 static double[] expandToLength(double[] array, int n) {
732 if (array.length != n) {
733 array = Arrays.copyOf(array, n);
734 Arrays.fill(array, array[0]);
735 }
736 return array;
737 }
738
739
740
741
742
743
744
745
746
747
748
749 static int[] expandToLength(int[] array, int n) {
750 if (array.length != n) {
751 array = Arrays.copyOf(array, n);
752 Arrays.fill(array, array[0]);
753 }
754 return array;
755 }
756
757
758
759
760
761
762
763 private static boolean closeToInteger(double p) {
764 return Math.abs(Math.rint(p) - p) < DELTA_P;
765 }
766 }