View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling;
19  
20  import java.util.Arrays;
21  import java.util.Collections;
22  import java.util.HashMap;
23  import java.util.List;
24  import java.util.Map;
25  import java.util.TreeMap;
26  import org.apache.commons.rng.UniformRandomProvider;
27  import org.junit.jupiter.api.Assertions;
28  import org.junit.jupiter.api.Test;
29  import org.junit.jupiter.params.ParameterizedTest;
30  import org.junit.jupiter.params.provider.ValueSource;
31  
32  /**
33   * Test class for {@link DiscreteProbabilityCollectionSampler}.
34   */
35  class DiscreteProbabilityCollectionSamplerTest {
36      /** RNG. */
37      private final UniformRandomProvider rng = RandomAssert.createRNG();
38  
39      @Test
40      void testPrecondition1() {
41          // Size mismatch
42          final List<Double> collection = Arrays.asList(1d, 2d);
43          final double[] probabilities = {0};
44          Assertions.assertThrows(IllegalArgumentException.class,
45              () -> new DiscreteProbabilityCollectionSampler<>(rng,
46                   collection,
47                   probabilities));
48      }
49  
50      @Test
51      void testPrecondition2() {
52          // Negative probability
53          final List<Double> collection = Arrays.asList(1d, 2d);
54          final double[] probabilities = {0, -1};
55          Assertions.assertThrows(IllegalArgumentException.class,
56              () -> new DiscreteProbabilityCollectionSampler<>(rng,
57                  collection,
58                  probabilities));
59      }
60  
61      @Test
62      void testPrecondition3() {
63          // Probabilities do not sum above 0
64          final List<Double> collection = Arrays.asList(1d, 2d);
65          final double[] probabilities = {0, 0};
66          Assertions.assertThrows(IllegalArgumentException.class,
67              () -> new DiscreteProbabilityCollectionSampler<>(rng,
68                  collection,
69                  probabilities));
70      }
71  
72      @ParameterizedTest
73      @ValueSource(doubles = {-1, Double.POSITIVE_INFINITY, Double.NaN})
74      void testPrecondition4(double p) {
75          final List<Double> collection = Arrays.asList(1d, 2d);
76          final double[] probabilities = {0, p};
77          Assertions.assertThrows(IllegalArgumentException.class,
78              () -> new DiscreteProbabilityCollectionSampler<>(rng,
79                  collection,
80                  probabilities));
81      }
82  
83      @ParameterizedTest
84      @ValueSource(doubles = {-1, Double.POSITIVE_INFINITY, Double.NaN})
85      void testPrecondition5(double p) {
86          final Map<String, Double> collection = new HashMap<>();
87          collection.put("one", 0.0);
88          collection.put("two", p);
89          Assertions.assertThrows(IllegalArgumentException.class,
90              () -> new DiscreteProbabilityCollectionSampler<>(rng,
91                  collection));
92      }
93  
94      @Test
95      void testPrecondition6() {
96          // Empty Map<T, Double> not allowed
97          final Map<String, Double> collection = Collections.emptyMap();
98          Assertions.assertThrows(IllegalArgumentException.class,
99              () -> new DiscreteProbabilityCollectionSampler<>(rng,
100                  collection));
101     }
102 
103     @Test
104     void testPrecondition7() {
105         // Empty List<T> not allowed
106         final List<Double> collection = Collections.emptyList();
107         final double[] probabilities = {};
108         Assertions.assertThrows(IllegalArgumentException.class,
109             () -> new DiscreteProbabilityCollectionSampler<>(rng,
110                 collection,
111                 probabilities));
112     }
113 
114     @Test
115     void testSample() {
116         final DiscreteProbabilityCollectionSampler<Double> sampler =
117             new DiscreteProbabilityCollectionSampler<>(rng,
118                                                        Arrays.asList(3d, -1d, 3d, 7d, -2d, 8d),
119                                                        new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
120         final double expectedMean = 3.4;
121         final double expectedVariance = 7.84;
122 
123         final int n = 100000000;
124         double sum = 0;
125         double sumOfSquares = 0;
126         for (int i = 0; i < n; i++) {
127             final double rand = sampler.sample();
128             sum += rand;
129             sumOfSquares += rand * rand;
130         }
131 
132         final double mean = sum / n;
133         Assertions.assertEquals(expectedMean, mean, 1e-3);
134         final double variance = sumOfSquares / n - mean * mean;
135         Assertions.assertEquals(expectedVariance, variance, 2e-3);
136     }
137 
138 
139     @Test
140     void testSampleUsingMap() {
141         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
142         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
143         final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
144         final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
145         final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
146             new DiscreteProbabilityCollectionSampler<>(rng1, items, probabilities);
147 
148         // Create a map version. The map iterator must be ordered so use a TreeMap.
149         final Map<Integer, Double> map = new TreeMap<>();
150         for (int i = 0; i < probabilities.length; i++) {
151             map.put(items.get(i), probabilities[i]);
152         }
153         final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
154             new DiscreteProbabilityCollectionSampler<>(rng2, map);
155 
156         for (int i = 0; i < 50; i++) {
157             Assertions.assertEquals(sampler1.sample(), sampler2.sample());
158         }
159     }
160 
161     /**
162      * Edge-case test:
163      * Create a sampler that will return 1 for nextDouble() forcing the search to
164      * identify the end item of the cumulative probability array.
165      */
166     @Test
167     void testSampleWithProbabilityAtLastItem() {
168         // Ensure the samples pick probability 0 (the first item) and then
169         // a probability (for the second item) that hits an edge case.
170         final UniformRandomProvider dummyRng = new UniformRandomProvider() {
171             private int count;
172 
173             @Override
174             public long nextLong() {
175                 return 0;
176             }
177 
178             @Override
179             public double nextDouble() {
180                 // Return 0 then the 1.0 for the probability
181                 return (count++ == 0) ? 0 : 1.0;
182             }
183         };
184 
185         final List<Double> items = Arrays.asList(1d, 2d);
186         final DiscreteProbabilityCollectionSampler<Double> sampler =
187             new DiscreteProbabilityCollectionSampler<>(dummyRng,
188                                                        items,
189                                                        new double[] {0.5, 0.5});
190         final Double item1 = sampler.sample();
191         final Double item2 = sampler.sample();
192         // Check they are in the list
193         Assertions.assertTrue(items.contains(item1), "Sample item1 is not from the list");
194         Assertions.assertTrue(items.contains(item2), "Sample item2 is not from the list");
195         // Test the two samples are different items
196         Assertions.assertNotSame(item1, item2, "Item1 and 2 should be different");
197     }
198 
199     /**
200      * Test the SharedStateSampler implementation.
201      */
202     @Test
203     void testSharedStateSampler() {
204         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
205         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
206         final List<Double> items = Arrays.asList(1d, 2d, 3d, 4d);
207         final DiscreteProbabilityCollectionSampler<Double> sampler1 =
208             new DiscreteProbabilityCollectionSampler<>(rng1,
209                                                        items,
210                                                        new double[] {0.1, 0.2, 0.3, 0.4});
211         final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
212         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
213     }
214 }