View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy;
19  
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.text.DecimalFormat;

import org.junit.Assert;

import org.apache.commons.numbers.complex.Complex;
import org.apache.commons.numbers.core.Precision;
import org.apache.commons.statistics.distribution.ContinuousDistribution;
import org.apache.commons.math4.legacy.core.FieldElement;
import org.apache.commons.math4.core.jdkmath.JdkMath;
import org.apache.commons.math4.legacy.util.ComplexFormat;
import org.apache.commons.math4.legacy.linear.FieldMatrix;
import org.apache.commons.math4.legacy.linear.RealMatrix;
import org.apache.commons.math4.legacy.linear.RealVector;
import org.apache.commons.math4.legacy.stat.inference.ChiSquareTest;
38  
39  /**
40   */
41  public final class TestUtils {
42      /**
43       * Collection of static methods used in math unit tests.
44       */
45      private TestUtils() {
46          super();
47      }
48  
49      /**
50       * Verifies that expected and actual are within delta, or are both NaN or
51       * infinities of the same sign.
52       */
53      public static void assertEquals(double expected, double actual, double delta) {
54          Assert.assertEquals(null, expected, actual, delta);
55      }
56  
57      /**
58       * Verifies that expected and actual are within delta, or are both NaN or
59       * infinities of the same sign.
60       */
61      public static void assertEquals(String msg, double expected, double actual, double delta) {
62          // check for NaN
63          if(Double.isNaN(expected)){
64              Assert.assertTrue("" + actual + " is not NaN.",
65                  Double.isNaN(actual));
66          } else {
67              Assert.assertEquals(msg, expected, actual, delta);
68          }
69      }
70  
71      /**
72       * Verifies that the two arguments are exactly the same, either
73       * both NaN or infinities of same sign, or identical floating point values.
74       */
75      public static void assertSame(double expected, double actual) {
76       Assert.assertEquals(expected, actual, 0);
77      }
78  
79      /**
80       * Verifies that real and imaginary parts of the two complex arguments
81       * are exactly the same.  Also ensures that NaN / infinite components match.
82       */
83      public static void assertSame(Complex expected, Complex actual) {
84          assertSame(expected.getReal(), actual.getReal());
85          assertSame(expected.getImaginary(), actual.getImaginary());
86      }
87  
88      /**
89       * Verifies that real and imaginary parts of the two complex arguments
90       * differ by at most delta.  Also ensures that NaN / infinite components match.
91       */
92      public static void assertEquals(Complex expected, Complex actual, double delta) {
93          Assert.assertEquals(expected.getReal(), actual.getReal(), delta);
94          Assert.assertEquals(expected.getImaginary(), actual.getImaginary(), delta);
95      }
96  
97      /**
98       * Verifies that two double arrays have equal entries, up to tolerance
99       */
100     public static void assertEquals(double expected[], double observed[], double tolerance) {
101         assertEquals("Array comparison failure", expected, observed, tolerance);
102     }
103 
104     /**
105      * Serializes an object to a bytes array and then recovers the object from the bytes array.
106      * Returns the deserialized object.
107      *
108      * @param o  object to serialize and recover
109      * @return  the recovered, deserialized object
110      */
111     public static Object serializeAndRecover(Object o) {
112         try {
113             // serialize the Object
114             ByteArrayOutputStream bos = new ByteArrayOutputStream();
115             ObjectOutputStream so = new ObjectOutputStream(bos);
116             so.writeObject(o);
117 
118             // deserialize the Object
119             ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
120             ObjectInputStream si = new ObjectInputStream(bis);
121             return si.readObject();
122         } catch (Exception e) {
123             throw new RuntimeException(e);
124         }
125     }
126 
127     /**
128      * Verifies that serialization preserves equals and hashCode.
129      * Serializes the object, then recovers it and checks equals and hash code.
130      *
131      * @param object  the object to serialize and recover
132      */
133     public static void checkSerializedEquality(Object object) {
134         Object object2 = serializeAndRecover(object);
135         Assert.assertEquals("Equals check", object, object2);
136         Assert.assertEquals("HashCode check", object.hashCode(), object2.hashCode());
137     }
138 
139     /**
140      * Verifies that the relative error in actual vs. expected is less than or
141      * equal to relativeError.  If expected is infinite or NaN, actual must be
142      * the same (NaN or infinity of the same sign).
143      *
144      * @param expected expected value
145      * @param actual  observed value
146      * @param relativeError  maximum allowable relative error
147      */
148     public static void assertRelativelyEquals(double expected, double actual,
149             double relativeError) {
150         assertRelativelyEquals(null, expected, actual, relativeError);
151     }
152 
153     /**
154      * Verifies that the relative error in actual vs. expected is less than or
155      * equal to relativeError.  If expected is infinite or NaN, actual must be
156      * the same (NaN or infinity of the same sign).
157      *
158      * @param msg  message to return with failure
159      * @param expected expected value
160      * @param actual  observed value
161      * @param relativeError  maximum allowable relative error
162      */
163     public static void assertRelativelyEquals(String msg, double expected,
164             double actual, double relativeError) {
165         if (Double.isNaN(expected)) {
166             Assert.assertTrue(msg, Double.isNaN(actual));
167         } else if (Double.isNaN(actual)) {
168             Assert.assertTrue(msg, Double.isNaN(expected));
169         } else if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
170             Assert.assertEquals(expected, actual, relativeError);
171         } else if (expected == 0.0) {
172             Assert.assertEquals(msg, actual, expected, relativeError);
173         } else {
174             double absError = JdkMath.abs(expected) * relativeError;
175             Assert.assertEquals(msg, expected, actual, absError);
176         }
177     }
178 
179     /**
180      * Fails iff values does not contain a number within epsilon of z.
181      *
182      * @param msg  message to return with failure
183      * @param values complex array to search
184      * @param z  value sought
185      * @param epsilon  tolerance
186      */
187     public static void assertContains(String msg, Complex[] values,
188                                       Complex z, double epsilon) {
189         for (Complex value : values) {
190             if (Precision.equals(value.getReal(), z.getReal(), epsilon) &&
191                 Precision.equals(value.getImaginary(), z.getImaginary(), epsilon)) {
192                 return;
193             }
194         }
195         Assert.fail(msg + " Unable to find " + (new ComplexFormat()).format(z));
196     }
197 
198     /**
199      * Fails iff values does not contain a number within epsilon of z.
200      *
201      * @param values complex array to search
202      * @param z  value sought
203      * @param epsilon  tolerance
204      */
205     public static void assertContains(Complex[] values,
206             Complex z, double epsilon) {
207         assertContains(null, values, z, epsilon);
208     }
209 
210     /**
211      * Fails iff values does not contain a number within epsilon of x.
212      *
213      * @param msg  message to return with failure
214      * @param values double array to search
215      * @param x value sought
216      * @param epsilon  tolerance
217      */
218     public static void assertContains(String msg, double[] values,
219             double x, double epsilon) {
220         for (double value : values) {
221             if (Precision.equals(value, x, epsilon)) {
222                 return;
223             }
224         }
225         Assert.fail(msg + " Unable to find " + x);
226     }
227 
228     /**
229      * Fails iff values does not contain a number within epsilon of x.
230      *
231      * @param values double array to search
232      * @param x value sought
233      * @param epsilon  tolerance
234      */
235     public static void assertContains(double[] values, double x,
236             double epsilon) {
237        assertContains(null, values, x, epsilon);
238     }
239 
240     /**
241      * Asserts that all entries of the specified vectors are equal to within a
242      * positive {@code delta}.
243      *
244      * @param message the identifying message for the assertion error (can be
245      * {@code null})
246      * @param expected expected value
247      * @param actual actual value
248      * @param delta the maximum difference between the entries of the expected
249      * and actual vectors for which both entries are still considered equal
250      */
251     public static void assertEquals(final String message,
252         final double[] expected, final RealVector actual, final double delta) {
253         final String msgAndSep = message.isEmpty() ? "" : message + ", ";
254         Assert.assertEquals(msgAndSep + "dimension", expected.length,
255             actual.getDimension());
256         for (int i = 0; i < expected.length; i++) {
257             Assert.assertEquals(msgAndSep + "entry #" + i, expected[i],
258                 actual.getEntry(i), delta);
259         }
260     }
261 
262     /**
263      * Asserts that all entries of the specified vectors are equal to within a
264      * positive {@code delta}.
265      *
266      * @param message the identifying message for the assertion error (can be
267      * {@code null})
268      * @param expected expected value
269      * @param actual actual value
270      * @param delta the maximum difference between the entries of the expected
271      * and actual vectors for which both entries are still considered equal
272      */
273     public static void assertEquals(final String message,
274         final RealVector expected, final RealVector actual, final double delta) {
275         final String msgAndSep = message.isEmpty() ? "" : message + ", ";
276         Assert.assertEquals(msgAndSep + "dimension", expected.getDimension(),
277             actual.getDimension());
278         final int dim = expected.getDimension();
279         for (int i = 0; i < dim; i++) {
280             Assert.assertEquals(msgAndSep + "entry #" + i,
281                 expected.getEntry(i), actual.getEntry(i), delta);
282         }
283     }
284 
285     /** verifies that two matrices are close (1-norm) */
286     public static void assertEquals(String msg, RealMatrix expected, RealMatrix observed, double tolerance) {
287 
288         Assert.assertNotNull(msg + "\nObserved should not be null",observed);
289 
290         if (expected.getColumnDimension() != observed.getColumnDimension() ||
291                 expected.getRowDimension() != observed.getRowDimension()) {
292             StringBuilder messageBuffer = new StringBuilder(msg);
293             messageBuffer.append("\nObserved has incorrect dimensions.");
294             messageBuffer.append("\nobserved is " + observed.getRowDimension() +
295                     " x " + observed.getColumnDimension());
296             messageBuffer.append("\nexpected " + expected.getRowDimension() +
297                     " x " + expected.getColumnDimension());
298             Assert.fail(messageBuffer.toString());
299         }
300 
301         RealMatrix delta = expected.subtract(observed);
302         if (delta.getNorm() >= tolerance) {
303             StringBuilder messageBuffer = new StringBuilder(msg);
304             messageBuffer.append("\nExpected: " + expected);
305             messageBuffer.append("\nObserved: " + observed);
306             messageBuffer.append("\nexpected - observed: " + delta);
307             Assert.fail(messageBuffer.toString());
308         }
309     }
310 
311     /** verifies that two matrices are equal */
312     public static void assertEquals(FieldMatrix<? extends FieldElement<?>> expected,
313                                     FieldMatrix<? extends FieldElement<?>> observed) {
314 
315         Assert.assertNotNull("Observed should not be null",observed);
316 
317         if (expected.getColumnDimension() != observed.getColumnDimension() ||
318                 expected.getRowDimension() != observed.getRowDimension()) {
319             StringBuilder messageBuffer = new StringBuilder();
320             messageBuffer.append("Observed has incorrect dimensions.");
321             messageBuffer.append("\nobserved is " + observed.getRowDimension() +
322                     " x " + observed.getColumnDimension());
323             messageBuffer.append("\nexpected " + expected.getRowDimension() +
324                     " x " + expected.getColumnDimension());
325             Assert.fail(messageBuffer.toString());
326         }
327 
328         for (int i = 0; i < expected.getRowDimension(); ++i) {
329             for (int j = 0; j < expected.getColumnDimension(); ++j) {
330                 FieldElement<?> eij = expected.getEntry(i, j);
331                 FieldElement<?> oij = observed.getEntry(i, j);
332                 Assert.assertEquals(eij, oij);
333             }
334         }
335     }
336 
337     /** verifies that two arrays are close (sup norm) */
338     public static void assertEquals(String msg, double[] expected, double[] observed, double tolerance) {
339         StringBuilder out = new StringBuilder(msg);
340         if (expected.length != observed.length) {
341             out.append("\n Arrays not same length. \n");
342             out.append("expected has length ");
343             out.append(expected.length);
344             out.append(" observed length = ");
345             out.append(observed.length);
346             Assert.fail(out.toString());
347         }
348         boolean failure = false;
349         for (int i=0; i < expected.length; i++) {
350             if (!Precision.equalsIncludingNaN(expected[i], observed[i], tolerance)) {
351                 failure = true;
352                 out.append("\n Elements at index ");
353                 out.append(i);
354                 out.append(" differ. ");
355                 out.append(" expected = ");
356                 out.append(expected[i]);
357                 out.append(" observed = ");
358                 out.append(observed[i]);
359             }
360         }
361         if (failure) {
362             Assert.fail(out.toString());
363         }
364     }
365 
366     /** verifies that two arrays are close (sup norm) */
367     public static void assertEquals(String msg, float[] expected, float[] observed, float tolerance) {
368         StringBuilder out = new StringBuilder(msg);
369         if (expected.length != observed.length) {
370             out.append("\n Arrays not same length. \n");
371             out.append("expected has length ");
372             out.append(expected.length);
373             out.append(" observed length = ");
374             out.append(observed.length);
375             Assert.fail(out.toString());
376         }
377         boolean failure = false;
378         for (int i=0; i < expected.length; i++) {
379             if (!Precision.equalsIncludingNaN(expected[i], observed[i], tolerance)) {
380                 failure = true;
381                 out.append("\n Elements at index ");
382                 out.append(i);
383                 out.append(" differ. ");
384                 out.append(" expected = ");
385                 out.append(expected[i]);
386                 out.append(" observed = ");
387                 out.append(observed[i]);
388             }
389         }
390         if (failure) {
391             Assert.fail(out.toString());
392         }
393     }
394 
395     /** verifies that two arrays are close (sup norm) */
396     public static void assertEquals(String msg, Complex[] expected, Complex[] observed, double tolerance) {
397         StringBuilder out = new StringBuilder(msg);
398         if (expected.length != observed.length) {
399             out.append("\n Arrays not same length. \n");
400             out.append("expected has length ");
401             out.append(expected.length);
402             out.append(" observed length = ");
403             out.append(observed.length);
404             Assert.fail(out.toString());
405         }
406         boolean failure = false;
407         for (int i=0; i < expected.length; i++) {
408             if (!Precision.equalsIncludingNaN(expected[i].getReal(), observed[i].getReal(), tolerance)) {
409                 failure = true;
410                 out.append("\n Real elements at index ");
411                 out.append(i);
412                 out.append(" differ. ");
413                 out.append(" expected = ");
414                 out.append(expected[i].getReal());
415                 out.append(" observed = ");
416                 out.append(observed[i].getReal());
417             }
418             if (!Precision.equalsIncludingNaN(expected[i].getImaginary(), observed[i].getImaginary(), tolerance)) {
419                 failure = true;
420                 out.append("\n Imaginary elements at index ");
421                 out.append(i);
422                 out.append(" differ. ");
423                 out.append(" expected = ");
424                 out.append(expected[i].getImaginary());
425                 out.append(" observed = ");
426                 out.append(observed[i].getImaginary());
427             }
428         }
429         if (failure) {
430             Assert.fail(out.toString());
431         }
432     }
433 
434     /** verifies that two arrays are equal */
435     public static <T extends FieldElement<T>> void assertEquals(T[] m, T[] n) {
436         if (m.length != n.length) {
437             Assert.fail("vectors not same length");
438         }
439         for (int i = 0; i < m.length; i++) {
440             Assert.assertEquals(m[i],n[i]);
441         }
442     }
443 
444     /**
445      * Computes the sum of squared deviations of <values> from <target>
446      * @param values array of deviates
447      * @param target value to compute deviations from
448      *
449      * @return sum of squared deviations
450      */
451     public static double sumSquareDev(double[] values, double target) {
452         double sumsq = 0d;
453         for (int i = 0; i < values.length; i++) {
454             final double dev = values[i] - target;
455             sumsq += dev * dev;
456         }
457         return sumsq;
458     }
459 
460     /**
461      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
462      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
463      *
464      * @param valueLabels labels for the values of the discrete distribution under test
465      * @param expected expected counts
466      * @param observed observed counts
467      * @param alpha significance level of the test
468      */
469     public static void assertChiSquareAccept(String[] valueLabels, double[] expected, long[] observed, double alpha) {
470         ChiSquareTest chiSquareTest = new ChiSquareTest();
471 
472         // Fail if we can reject null hypothesis that distributions are the same
473         if (chiSquareTest.chiSquareTest(expected, observed, alpha)) {
474             StringBuilder msgBuffer = new StringBuilder();
475             DecimalFormat df = new DecimalFormat("#.##");
476             msgBuffer.append("Chisquare test failed");
477             msgBuffer.append(" p-value = ");
478             msgBuffer.append(chiSquareTest.chiSquareTest(expected, observed));
479             msgBuffer.append(" chisquare statistic = ");
480             msgBuffer.append(chiSquareTest.chiSquare(expected, observed));
481             msgBuffer.append(". \n");
482             msgBuffer.append("value\texpected\tobserved\n");
483             for (int i = 0; i < expected.length; i++) {
484                 msgBuffer.append(valueLabels[i]);
485                 msgBuffer.append("\t");
486                 msgBuffer.append(df.format(expected[i]));
487                 msgBuffer.append("\t\t");
488                 msgBuffer.append(observed[i]);
489                 msgBuffer.append("\n");
490             }
491             msgBuffer.append("This test can fail randomly due to sampling error with probability ");
492             msgBuffer.append(alpha);
493             msgBuffer.append(".");
494             Assert.fail(msgBuffer.toString());
495         }
496     }
497 
498     /**
499      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
500      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
501      *
502      * @param values integer values whose observed and expected counts are being compared
503      * @param expected expected counts
504      * @param observed observed counts
505      * @param alpha significance level of the test
506      */
507     public static void assertChiSquareAccept(int[] values, double[] expected, long[] observed, double alpha) {
508         String[] labels = new String[values.length];
509         for (int i = 0; i < values.length; i++) {
510             labels[i] = Integer.toString(values[i]);
511         }
512         assertChiSquareAccept(labels, expected, observed, alpha);
513     }
514 
515     /**
516      * Asserts the null hypothesis for a ChiSquare test.  Fails and dumps arguments and test
517      * statistics if the null hypothesis can be rejected with confidence 100 * (1 - alpha)%
518      *
519      * @param expected expected counts
520      * @param observed observed counts
521      * @param alpha significance level of the test
522      */
523     public static void assertChiSquareAccept(double[] expected, long[] observed, double alpha) {
524         String[] labels = new String[expected.length];
525         for (int i = 0; i < labels.length; i++) {
526             labels[i] = Integer.toString(i + 1);
527         }
528         assertChiSquareAccept(labels, expected, observed, alpha);
529     }
530 
531     /**
532      * Computes the 25th, 50th and 75th percentiles of the given distribution and returns
533      * these values in an array.
534      */
535     public static double[] getDistributionQuartiles(ContinuousDistribution distribution) {
536         double[] quantiles = new double[3];
537         quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
538         quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
539         quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
540         return quantiles;
541     }
542 
543     /**
544      * Updates observed counts of values in quartiles.
545      * counts[0] <-> 1st quartile ... counts[3] <-> top quartile
546      */
547     public static void updateCounts(double value, long[] counts, double[] quartiles) {
548         if (value < quartiles[0]) {
549             counts[0]++;
550         } else if (value > quartiles[2]) {
551             counts[3]++;
552         } else if (value > quartiles[1]) {
553             counts[2]++;
554         } else {
555             counts[1]++;
556         }
557     }
558 
559     /**
560      * Eliminates points with zero mass from densityPoints and densityValues parallel
561      * arrays.  Returns the number of positive mass points and collapses the arrays so
562      * that the first <returned value> elements of the input arrays represent the positive
563      * mass points.
564      */
565     public static int eliminateZeroMassPoints(int[] densityPoints, double[] densityValues) {
566         int positiveMassCount = 0;
567         for (int i = 0; i < densityValues.length; i++) {
568             if (densityValues[i] > 0) {
569                 positiveMassCount++;
570             }
571         }
572         if (positiveMassCount < densityValues.length) {
573             int[] newPoints = new int[positiveMassCount];
574             double[] newValues = new double[positiveMassCount];
575             int j = 0;
576             for (int i = 0; i < densityValues.length; i++) {
577                 if (densityValues[i] > 0) {
578                     newPoints[j] = densityPoints[i];
579                     newValues[j] = densityValues[i];
580                     j++;
581                 }
582             }
583             System.arraycopy(newPoints,0,densityPoints,0,positiveMassCount);
584             System.arraycopy(newValues,0,densityValues,0,positiveMassCount);
585         }
586         return positiveMassCount;
587     }
588 }