1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling;
19
20 import java.util.Arrays;
21 import java.util.Collections;
22 import java.util.HashMap;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.TreeMap;
26 import org.apache.commons.rng.UniformRandomProvider;
27 import org.junit.jupiter.api.Assertions;
28 import org.junit.jupiter.api.Test;
29 import org.junit.jupiter.params.ParameterizedTest;
30 import org.junit.jupiter.params.provider.ValueSource;
31
32
33
34
35 class DiscreteProbabilityCollectionSamplerTest {
36
37 private final UniformRandomProvider rng = RandomAssert.createRNG();
38
39 @Test
40 void testPrecondition1() {
41
42 final List<Double> collection = Arrays.asList(1d, 2d);
43 final double[] probabilities = {0};
44 Assertions.assertThrows(IllegalArgumentException.class,
45 () -> new DiscreteProbabilityCollectionSampler<>(rng,
46 collection,
47 probabilities));
48 }
49
50 @Test
51 void testPrecondition2() {
52
53 final List<Double> collection = Arrays.asList(1d, 2d);
54 final double[] probabilities = {0, -1};
55 Assertions.assertThrows(IllegalArgumentException.class,
56 () -> new DiscreteProbabilityCollectionSampler<>(rng,
57 collection,
58 probabilities));
59 }
60
61 @Test
62 void testPrecondition3() {
63
64 final List<Double> collection = Arrays.asList(1d, 2d);
65 final double[] probabilities = {0, 0};
66 Assertions.assertThrows(IllegalArgumentException.class,
67 () -> new DiscreteProbabilityCollectionSampler<>(rng,
68 collection,
69 probabilities));
70 }
71
72 @ParameterizedTest
73 @ValueSource(doubles = {-1, Double.POSITIVE_INFINITY, Double.NaN})
74 void testPrecondition4(double p) {
75 final List<Double> collection = Arrays.asList(1d, 2d);
76 final double[] probabilities = {0, p};
77 Assertions.assertThrows(IllegalArgumentException.class,
78 () -> new DiscreteProbabilityCollectionSampler<>(rng,
79 collection,
80 probabilities));
81 }
82
83 @ParameterizedTest
84 @ValueSource(doubles = {-1, Double.POSITIVE_INFINITY, Double.NaN})
85 void testPrecondition5(double p) {
86 final Map<String, Double> collection = new HashMap<>();
87 collection.put("one", 0.0);
88 collection.put("two", p);
89 Assertions.assertThrows(IllegalArgumentException.class,
90 () -> new DiscreteProbabilityCollectionSampler<>(rng,
91 collection));
92 }
93
94 @Test
95 void testPrecondition6() {
96
97 final Map<String, Double> collection = Collections.emptyMap();
98 Assertions.assertThrows(IllegalArgumentException.class,
99 () -> new DiscreteProbabilityCollectionSampler<>(rng,
100 collection));
101 }
102
103 @Test
104 void testPrecondition7() {
105
106 final List<Double> collection = Collections.emptyList();
107 final double[] probabilities = {};
108 Assertions.assertThrows(IllegalArgumentException.class,
109 () -> new DiscreteProbabilityCollectionSampler<>(rng,
110 collection,
111 probabilities));
112 }
113
114 @Test
115 void testSample() {
116 final DiscreteProbabilityCollectionSampler<Double> sampler =
117 new DiscreteProbabilityCollectionSampler<>(rng,
118 Arrays.asList(3d, -1d, 3d, 7d, -2d, 8d),
119 new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
120 final double expectedMean = 3.4;
121 final double expectedVariance = 7.84;
122
123 final int n = 100000000;
124 double sum = 0;
125 double sumOfSquares = 0;
126 for (int i = 0; i < n; i++) {
127 final double rand = sampler.sample();
128 sum += rand;
129 sumOfSquares += rand * rand;
130 }
131
132 final double mean = sum / n;
133 Assertions.assertEquals(expectedMean, mean, 1e-3);
134 final double variance = sumOfSquares / n - mean * mean;
135 Assertions.assertEquals(expectedVariance, variance, 2e-3);
136 }
137
138
139 @Test
140 void testSampleUsingMap() {
141 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
142 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
143 final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
144 final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
145 final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
146 new DiscreteProbabilityCollectionSampler<>(rng1, items, probabilities);
147
148
149 final Map<Integer, Double> map = new TreeMap<>();
150 for (int i = 0; i < probabilities.length; i++) {
151 map.put(items.get(i), probabilities[i]);
152 }
153 final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
154 new DiscreteProbabilityCollectionSampler<>(rng2, map);
155
156 for (int i = 0; i < 50; i++) {
157 Assertions.assertEquals(sampler1.sample(), sampler2.sample());
158 }
159 }
160
161
162
163
164
165
166 @Test
167 void testSampleWithProbabilityAtLastItem() {
168
169
170 final UniformRandomProvider dummyRng = new UniformRandomProvider() {
171 private int count;
172
173 @Override
174 public long nextLong() {
175 return 0;
176 }
177
178 @Override
179 public double nextDouble() {
180
181 return (count++ == 0) ? 0 : 1.0;
182 }
183 };
184
185 final List<Double> items = Arrays.asList(1d, 2d);
186 final DiscreteProbabilityCollectionSampler<Double> sampler =
187 new DiscreteProbabilityCollectionSampler<>(dummyRng,
188 items,
189 new double[] {0.5, 0.5});
190 final Double item1 = sampler.sample();
191 final Double item2 = sampler.sample();
192
193 Assertions.assertTrue(items.contains(item1), "Sample item1 is not from the list");
194 Assertions.assertTrue(items.contains(item2), "Sample item2 is not from the list");
195
196 Assertions.assertNotSame(item1, item2, "Item1 and 2 should be different");
197 }
198
199
200
201
202 @Test
203 void testSharedStateSampler() {
204 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
205 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
206 final List<Double> items = Arrays.asList(1d, 2d, 3d, 4d);
207 final DiscreteProbabilityCollectionSampler<Double> sampler1 =
208 new DiscreteProbabilityCollectionSampler<>(rng1,
209 items,
210 new double[] {0.1, 0.2, 0.3, 0.4});
211 final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
212 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
213 }
214 }