1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.distribution;
18
19 import java.util.ArrayList;
20 import java.util.LinkedHashMap;
21 import java.util.List;
22 import java.util.Map;
23 import java.util.Map.Entry;
24
25 import org.apache.commons.statistics.distribution.DiscreteDistribution;
26 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
27 import org.apache.commons.math4.legacy.exception.MathArithmeticException;
28 import org.apache.commons.math4.legacy.exception.NotANumberException;
29 import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
30 import org.apache.commons.math4.legacy.exception.NotPositiveException;
31 import org.apache.commons.rng.UniformRandomProvider;
32 import org.apache.commons.math4.legacy.core.Pair;
33
34
35
36
37
38
39
40
41
42
43
44 public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
45
46
47
48
49 protected final EnumeratedDistribution<Integer> innerDistribution;
50
51
52
53
54
55
56
57
58
59
60
61
62
63 public EnumeratedIntegerDistribution(final int[] singletons,
64 final double[] probabilities)
65 throws DimensionMismatchException,
66 NotPositiveException,
67 MathArithmeticException,
68 NotFiniteNumberException,
69 NotANumberException {
70 innerDistribution = new EnumeratedDistribution<>(createDistribution(singletons,
71 probabilities));
72 }
73
74
75
76
77
78
79
80 public EnumeratedIntegerDistribution(final int[] data) {
81 final Map<Integer, Integer> dataMap = new LinkedHashMap<>();
82 for (int value : data) {
83 dataMap.merge(value, 1, Integer::sum);
84 }
85 final int massPoints = dataMap.size();
86 final double denom = data.length;
87 final int[] values = new int[massPoints];
88 final double[] probabilities = new double[massPoints];
89 int index = 0;
90 for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
91 values[index] = entry.getKey();
92 probabilities[index] = entry.getValue().intValue() / denom;
93 index++;
94 }
95 innerDistribution = new EnumeratedDistribution<>(createDistribution(values, probabilities));
96 }
97
98
99
100
101
102
103
104
105 private static List<Pair<Integer, Double>> createDistribution(int[] singletons, double[] probabilities) {
106 if (singletons.length != probabilities.length) {
107 throw new DimensionMismatchException(probabilities.length, singletons.length);
108 }
109
110 final List<Pair<Integer, Double>> samples = new ArrayList<>(singletons.length);
111
112 for (int i = 0; i < singletons.length; i++) {
113 samples.add(new Pair<>(singletons[i], probabilities[i]));
114 }
115 return samples;
116 }
117
118
119
120
121 @Override
122 public double probability(final int x) {
123 return innerDistribution.probability(x);
124 }
125
126
127
128
129 @Override
130 public double cumulativeProbability(final int x) {
131 double probability = 0;
132
133 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
134 if (sample.getKey() <= x) {
135 probability += sample.getValue();
136 }
137 }
138
139 return probability;
140 }
141
142
143
144
145
146
147 @Override
148 public double getMean() {
149 double mean = 0;
150
151 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
152 mean += sample.getValue() * sample.getKey();
153 }
154
155 return mean;
156 }
157
158
159
160
161
162
163 @Override
164 public double getVariance() {
165 double mean = 0;
166 double meanOfSquares = 0;
167
168 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
169 mean += sample.getValue() * sample.getKey();
170 meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
171 }
172
173 return meanOfSquares - mean * mean;
174 }
175
176
177
178
179
180
181
182
183 @Override
184 public int getSupportLowerBound() {
185 int min = Integer.MAX_VALUE;
186 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
187 if (sample.getKey() < min && sample.getValue() > 0) {
188 min = sample.getKey();
189 }
190 }
191
192 return min;
193 }
194
195
196
197
198
199
200
201
202 @Override
203 public int getSupportUpperBound() {
204 int max = Integer.MIN_VALUE;
205 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
206 if (sample.getKey() > max && sample.getValue() > 0) {
207 max = sample.getKey();
208 }
209 }
210
211 return max;
212 }
213
214
215
216
217
218
219 @Override
220 public DiscreteDistribution.Sampler createSampler(final UniformRandomProvider rng) {
221 return innerDistribution.createSampler(rng)::sample;
222 }
223 }