1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  package org.apache.commons.math4.legacy.distribution;
18  
19  import static org.junit.Assert.assertEquals;
20  
21  import java.util.ArrayList;
22  import java.util.List;
23  
24  import org.apache.commons.statistics.distribution.ContinuousDistribution;
25  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
26  import org.apache.commons.math4.legacy.exception.MathArithmeticException;
27  import org.apache.commons.math4.legacy.exception.NotANumberException;
28  import org.apache.commons.math4.legacy.exception.NotFiniteNumberException;
29  import org.apache.commons.math4.legacy.exception.NotPositiveException;
30  import org.apache.commons.math4.core.jdkmath.JdkMath;
31  import org.apache.commons.math4.legacy.core.Pair;
32  import org.apache.commons.rng.UniformRandomProvider;
33  import org.apache.commons.rng.simple.RandomSource;
34  import org.junit.Assert;
35  import org.junit.Test;
36  
37  
38  
39  
40  
41  public class EnumeratedRealDistributionTest {
42  
43      
44  
45  
46      private final EnumeratedRealDistribution testDistribution;
47  
48      
49  
50  
51      public EnumeratedRealDistributionTest() {
52          
53          
54          testDistribution = new EnumeratedRealDistribution(
55                  new double[]{3.0, -1.0, 3.0, 7.0, -2.0, 8.0},
56                  new double[]{0.2, 0.2, 0.3, 0.3, 0.0, 0.0});
57      }
58  
59      
60  
61  
62  
63      @Test
64      public void testExceptions() {
65          EnumeratedRealDistribution invalid = null;
66          try {
67              invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0});
68              Assert.fail("Expected DimensionMismatchException");
69          } catch (DimensionMismatchException e) {
70          }
71          try{
72          invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, -1.0});
73              Assert.fail("Expected NotPositiveException");
74          } catch (NotPositiveException e) {
75          }
76          try {
77              invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, 0.0});
78              Assert.fail("Expected MathArithmeticException");
79          } catch (MathArithmeticException e) {
80          }
81          try {
82              invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.NaN});
83              Assert.fail("Expected NotANumberException");
84          } catch (NotANumberException e) {
85          }
86          try {
87              invalid = new EnumeratedRealDistribution(new double[]{1.0, 2.0}, new double[]{0.0, Double.POSITIVE_INFINITY});
88              Assert.fail("Expected NotFiniteNumberException");
89          } catch (NotFiniteNumberException e) {
90          }
91          Assert.assertNull("Expected non-initialized DiscreteRealDistribution", invalid);
92      }
93  
94      
95  
96  
97      @Test
98      public void testProbability() {
99          double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
100         double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
101         for (int p = 0; p < points.length; p++) {
102             double density = testDistribution.density(points[p]);
103             Assert.assertEquals(results[p], density, 0.0);
104         }
105     }
106 
107     
108 
109 
110     @Test
111     public void testDensity() {
112         double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
113         double[] results = new double[]{0, 0.2, 0, 0, 0, 0.5, 0, 0, 0, 0.3, 0};
114         for (int p = 0; p < points.length; p++) {
115             double density = testDistribution.density(points[p]);
116             Assert.assertEquals(results[p], density, 0.0);
117         }
118     }
119 
120     
121 
122 
123     @Test
124     public void testCumulativeProbability() {
125         double[] points = new double[]{-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
126         double[] results = new double[]{0, 0.2, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7, 0.7, 1.0, 1.0};
127         for (int p = 0; p < points.length; p++) {
128             double probability = testDistribution.cumulativeProbability(points[p]);
129             Assert.assertEquals(results[p], probability, 1e-10);
130         }
131     }
132 
133     
134 
135 
136     @Test
137     public void testGetNumericalMean() {
138         Assert.assertEquals(3.4, testDistribution.getMean(), 1e-10);
139     }
140 
141     
142 
143 
144     @Test
145     public void testGetNumericalVariance() {
146         Assert.assertEquals(7.84, testDistribution.getVariance(), 1e-10);
147     }
148 
149     
150 
151 
152     @Test
153     public void testGetSupportLowerBound() {
154         Assert.assertEquals(-1, testDistribution.getSupportLowerBound(), 0);
155     }
156 
157     
158 
159 
160     @Test
161     public void testGetSupportUpperBound() {
162         Assert.assertEquals(7, testDistribution.getSupportUpperBound(), 0);
163     }
164 
165     
166 
167 
168     @Test
169     public void testSample() {
170         final int n = 1000000;
171         final ContinuousDistribution.Sampler sampler =
172             testDistribution.createSampler(RandomSource.XO_RO_SHI_RO_128_PP.create());
173         final double[] samples = AbstractRealDistribution.sample(n, sampler);
174         Assert.assertEquals(n, samples.length);
175         double sum = 0;
176         double sumOfSquares = 0;
177         for (int i = 0; i < samples.length; i++) {
178             sum += samples[i];
179             sumOfSquares += samples[i] * samples[i];
180         }
181         final double mean = testDistribution.getMean();
182         Assert.assertEquals("Mean", mean, sum / n, mean * 1e-2);
183         final double var = testDistribution.getVariance();
184         Assert.assertEquals("Variance", var, sumOfSquares / n - JdkMath.pow(sum / n, 2), var * 1e-2);
185     }
186 
187     @Test
188     public void testIssue942() {
189         List<Pair<Object,Double>> list = new ArrayList<>();
190         list.add(new Pair<Object, Double>(new Object() {}, Double.valueOf(0)));
191         list.add(new Pair<Object, Double>(new Object() {}, Double.valueOf(1)));
192         final UniformRandomProvider rng = RandomSource.WELL_512_A.create();
193         Assert.assertEquals(1, new EnumeratedDistribution<>(list).createSampler(rng).sample(1).length);
194     }
195 
196     @Test
197     public void testIssue1065() {
198         
199         
200         
201         
202         
203         
204         
205         
206         
207         
208         
209         
210         
211         
212         
213         
214         
215         
216         
217         
218         
219         
220         
221         
222 
223         EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(
224                 new double[] { 14.0, 18.0, 21.0, 28.0, 31.0, 33.0 },
225                 new double[] { 4.0 / 16.0, 5.0 / 16.0, 0.0 / 16.0, 3.0 / 16.0, 1.0 / 16.0, 3.0 / 16.0 });
226 
227         assertEquals(14.0, distribution.inverseCumulativeProbability(0.0000), 0.0);
228         assertEquals(14.0, distribution.inverseCumulativeProbability(0.2500), 0.0);
229         assertEquals(33.0, distribution.inverseCumulativeProbability(1.0000), 0.0);
230 
231         assertEquals(18.0, distribution.inverseCumulativeProbability(0.5000), 0.0);
232         assertEquals(18.0, distribution.inverseCumulativeProbability(0.5624), 0.0);
233         assertEquals(28.0, distribution.inverseCumulativeProbability(0.5626), 0.0);
234         assertEquals(31.0, distribution.inverseCumulativeProbability(0.7600), 0.0);
235         assertEquals(18.0, distribution.inverseCumulativeProbability(0.5625), 0.0);
236         assertEquals(28.0, distribution.inverseCumulativeProbability(0.7500), 0.0);
237     }
238 
239     @Test
240     public void testCreateFromDoubles() {
241         final double[] data = new double[] {0, 1, 1, 2, 2, 2};
242         EnumeratedRealDistribution distribution = new EnumeratedRealDistribution(data);
243         assertEquals(0.5, distribution.density(2), 0);
244         assertEquals(0.5, distribution.cumulativeProbability(1), 0);
245     }
246 }