View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.stat.descriptive;
18  
19  
20  import java.util.Locale;
21  
22  import org.apache.commons.math4.legacy.TestUtils;
23  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
24  import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
25  import org.apache.commons.math4.legacy.stat.descriptive.moment.Mean;
26  import org.apache.commons.math4.core.jdkmath.JdkMath;
27  import org.junit.Test;
28  import org.junit.Assert;
29  
30  /**
31   * Test cases for the {@link MultivariateSummaryStatistics} class.
32   *
33   */
34  
35  public class MultivariateSummaryStatisticsTest {
36  
37      protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
38          return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected);
39      }
40  
41      @Test
42      public void testSetterInjection() {
43          MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
44          u.setMeanImpl(new StorelessUnivariateStatistic[] {
45                          new SumMean(), new SumMean()
46                        });
47          u.addValue(new double[] { 1, 2 });
48          u.addValue(new double[] { 3, 4 });
49          Assert.assertEquals(4, u.getMean()[0], 1E-14);
50          Assert.assertEquals(6, u.getMean()[1], 1E-14);
51          u.clear();
52          u.addValue(new double[] { 1, 2 });
53          u.addValue(new double[] { 3, 4 });
54          Assert.assertEquals(4, u.getMean()[0], 1E-14);
55          Assert.assertEquals(6, u.getMean()[1], 1E-14);
56          u.clear();
57          u.setMeanImpl(new StorelessUnivariateStatistic[] {
58                          new Mean(), new Mean()
59                        }); // OK after clear
60          u.addValue(new double[] { 1, 2 });
61          u.addValue(new double[] { 3, 4 });
62          Assert.assertEquals(2, u.getMean()[0], 1E-14);
63          Assert.assertEquals(3, u.getMean()[1], 1E-14);
64          Assert.assertEquals(2, u.getDimension());
65      }
66  
67      @Test
68      public void testSetterIllegalState() {
69          MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
70          u.addValue(new double[] { 1, 2 });
71          u.addValue(new double[] { 3, 4 });
72          try {
73              u.setMeanImpl(new StorelessUnivariateStatistic[] { new SumMean(), new SumMean() });
74              Assert.fail("Expecting MathIllegalStateException");
75          } catch (MathIllegalStateException ex) {
76              // expected
77          }
78      }
79  
80      @Test
81      public void testToString() {
82          MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true);
83          stats.addValue(new double[] {1, 3});
84          stats.addValue(new double[] {2, 2});
85          stats.addValue(new double[] {3, 1});
86          Locale d = Locale.getDefault();
87          Locale.setDefault(Locale.US);
88          final String suffix = System.getProperty("line.separator");
89          Assert.assertEquals("MultivariateSummaryStatistics:" + suffix+
90                       "n: 3" +suffix+
91                       "min: 1.0, 1.0" +suffix+
92                       "max: 3.0, 3.0" +suffix+
93                       "mean: 2.0, 2.0" +suffix+
94                       "geometric mean: 1.817..., 1.817..." +suffix+
95                       "sum of squares: 14.0, 14.0" +suffix+
96                       "sum of logarithms: 1.791..., 1.791..." +suffix+
97                       "standard deviation: 1.0, 1.0" +suffix+
98                       "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}" +suffix,
99                       stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1..."));
100         Locale.setDefault(d);
101     }
102 
103     @Test
104     public void testShuffledStatistics() {
105         // the purpose of this test is only to check the get/set methods
106         // we are aware shuffling statistics like this is really not
107         // something sensible to do in production ...
108         MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true);
109         MultivariateSummaryStatistics shuffled  = createMultivariateSummaryStatistics(2, true);
110 
111         StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl();
112         shuffled.setGeoMeanImpl(shuffled.getMeanImpl());
113         shuffled.setMeanImpl(shuffled.getMaxImpl());
114         shuffled.setMaxImpl(shuffled.getMinImpl());
115         shuffled.setMinImpl(shuffled.getSumImpl());
116         shuffled.setSumImpl(shuffled.getSumsqImpl());
117         shuffled.setSumsqImpl(shuffled.getSumLogImpl());
118         shuffled.setSumLogImpl(tmp);
119 
120         for (int i = 100; i > 0; --i) {
121             reference.addValue(new double[] {i, i});
122             shuffled.addValue(new double[] {i, i});
123         }
124 
125         TestUtils.assertEquals(reference.getMean(),          shuffled.getGeometricMean(), 1.0e-10);
126         TestUtils.assertEquals(reference.getMax(),           shuffled.getMean(),          1.0e-10);
127         TestUtils.assertEquals(reference.getMin(),           shuffled.getMax(),           1.0e-10);
128         TestUtils.assertEquals(reference.getSum(),           shuffled.getMin(),           1.0e-10);
129         TestUtils.assertEquals(reference.getSumSq(),         shuffled.getSum(),           1.0e-10);
130         TestUtils.assertEquals(reference.getSumLog(),        shuffled.getSumSq(),         1.0e-10);
131         TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(),        1.0e-10);
132     }
133 
134     /**
135      * Bogus mean implementation to test setter injection.
136      * Returns the sum instead of the mean.
137      */
138     static class SumMean implements StorelessUnivariateStatistic {
139         private double sum = 0;
140         private long n = 0;
141         @Override
142         public double evaluate(double[] values, int begin, int length) {
143             return 0;
144         }
145         @Override
146         public double evaluate(double[] values) {
147             return 0;
148         }
149         @Override
150         public void clear() {
151           sum = 0;
152           n = 0;
153         }
154         @Override
155         public long getN() {
156             return n;
157         }
158         @Override
159         public double getResult() {
160             return sum;
161         }
162         @Override
163         public void increment(double d) {
164             sum += d;
165             n++;
166         }
167         @Override
168         public void incrementAll(double[] values, int start, int length) {
169         }
170         @Override
171         public void incrementAll(double[] values) {
172         }
173         @Override
174         public StorelessUnivariateStatistic copy() {
175             return new SumMean();
176         }
177     }
178 
179     @Test
180     public void testDimension() {
181         try {
182             createMultivariateSummaryStatistics(2, true).addValue(new double[3]);
183             Assert.fail("Expecting DimensionMismatchException");
184         } catch (DimensionMismatchException dme) {
185             // expected behavior
186         }
187     }
188 
189     /** test stats */
190     @Test
191     public void testStats() {
192         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
193         Assert.assertEquals(0, u.getN());
194         u.addValue(new double[] { 1, 2 });
195         u.addValue(new double[] { 2, 3 });
196         u.addValue(new double[] { 2, 3 });
197         u.addValue(new double[] { 3, 4 });
198         Assert.assertEquals( 4, u.getN());
199         Assert.assertEquals( 8, u.getSum()[0], 1.0e-10);
200         Assert.assertEquals(12, u.getSum()[1], 1.0e-10);
201         Assert.assertEquals(18, u.getSumSq()[0], 1.0e-10);
202         Assert.assertEquals(38, u.getSumSq()[1], 1.0e-10);
203         Assert.assertEquals( 1, u.getMin()[0], 1.0e-10);
204         Assert.assertEquals( 2, u.getMin()[1], 1.0e-10);
205         Assert.assertEquals( 3, u.getMax()[0], 1.0e-10);
206         Assert.assertEquals( 4, u.getMax()[1], 1.0e-10);
207         Assert.assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10);
208         Assert.assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10);
209         Assert.assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10);
210         Assert.assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10);
211         Assert.assertEquals( 2, u.getMean()[0], 1.0e-10);
212         Assert.assertEquals( 3, u.getMean()[1], 1.0e-10);
213         Assert.assertEquals(JdkMath.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10);
214         Assert.assertEquals(JdkMath.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10);
215         Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10);
216         Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10);
217         Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10);
218         Assert.assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10);
219         u.clear();
220         Assert.assertEquals(0, u.getN());
221     }
222 
223     @Test
224     public void testN0andN1Conditions() {
225         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
226         Assert.assertTrue(Double.isNaN(u.getMean()[0]));
227         Assert.assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
228 
229         /* n=1 */
230         u.addValue(new double[] { 1 });
231         Assert.assertEquals(1.0, u.getMean()[0], 1.0e-10);
232         Assert.assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10);
233         Assert.assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10);
234 
235         /* n=2 */
236         u.addValue(new double[] { 2 });
237         Assert.assertTrue(u.getStandardDeviation()[0] > 0);
238     }
239 
240     @Test
241     public void testNaNContracts() {
242         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
243         Assert.assertTrue(Double.isNaN(u.getMean()[0]));
244         Assert.assertTrue(Double.isNaN(u.getMin()[0]));
245         Assert.assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
246         Assert.assertTrue(Double.isNaN(u.getGeometricMean()[0]));
247 
248         u.addValue(new double[] { 1.0 });
249         Assert.assertFalse(Double.isNaN(u.getMean()[0]));
250         Assert.assertFalse(Double.isNaN(u.getMin()[0]));
251         Assert.assertFalse(Double.isNaN(u.getStandardDeviation()[0]));
252         Assert.assertFalse(Double.isNaN(u.getGeometricMean()[0]));
253     }
254 
255     @Test
256     public void testEqualsAndHashCode() {
257         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
258         MultivariateSummaryStatistics t = null;
259         int emptyHash = u.hashCode();
260         Assert.assertEquals(u, u);
261         Assert.assertNotEquals(u, t);
262         Assert.assertFalse(u.equals(Double.valueOf(0)));
263         t = createMultivariateSummaryStatistics(2, true);
264         Assert.assertEquals(t, u);
265         Assert.assertEquals(u, t);
266         Assert.assertEquals(emptyHash, t.hashCode());
267 
268         // Add some data to u
269         u.addValue(new double[] { 2d, 1d });
270         u.addValue(new double[] { 1d, 1d });
271         u.addValue(new double[] { 3d, 1d });
272         u.addValue(new double[] { 4d, 1d });
273         u.addValue(new double[] { 5d, 1d });
274         Assert.assertFalse(t.equals(u));
275         Assert.assertFalse(u.equals(t));
276         Assert.assertTrue(u.hashCode() != t.hashCode());
277 
278         //Add data in same order to t
279         t.addValue(new double[] { 2d, 1d });
280         t.addValue(new double[] { 1d, 1d });
281         t.addValue(new double[] { 3d, 1d });
282         t.addValue(new double[] { 4d, 1d });
283         t.addValue(new double[] { 5d, 1d });
284         Assert.assertTrue(t.equals(u));
285         Assert.assertTrue(u.equals(t));
286         Assert.assertEquals(u.hashCode(), t.hashCode());
287 
288         // Clear and make sure summaries are indistinguishable from empty summary
289         u.clear();
290         t.clear();
291         Assert.assertTrue(t.equals(u));
292         Assert.assertTrue(u.equals(t));
293         Assert.assertEquals(emptyHash, t.hashCode());
294         Assert.assertEquals(emptyHash, u.hashCode());
295     }
296 }