1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy;
19
20 import java.io.ByteArrayInputStream;
21 import java.io.ByteArrayOutputStream;
22 import java.io.ObjectInputStream;
23 import java.io.ObjectOutputStream;
24 import java.text.DecimalFormat;
25
26 import org.junit.Assert;
27
28 import org.apache.commons.numbers.complex.Complex;
29 import org.apache.commons.numbers.core.Precision;
30 import org.apache.commons.statistics.distribution.ContinuousDistribution;
31 import org.apache.commons.math4.legacy.core.FieldElement;
32 import org.apache.commons.math4.core.jdkmath.JdkMath;
33 import org.apache.commons.math4.legacy.util.ComplexFormat;
34 import org.apache.commons.math4.legacy.linear.FieldMatrix;
35 import org.apache.commons.math4.legacy.linear.RealMatrix;
36 import org.apache.commons.math4.legacy.linear.RealVector;
37 import org.apache.commons.math4.legacy.stat.inference.ChiSquareTest;
38
39
40
41 public final class TestUtils {
42
43
44
45 private TestUtils() {
46 super();
47 }
48
49
50
51
52
53 public static void assertEquals(double expected, double actual, double delta) {
54 Assert.assertEquals(null, expected, actual, delta);
55 }
56
57
58
59
60
61 public static void assertEquals(String msg, double expected, double actual, double delta) {
62
63 if(Double.isNaN(expected)){
64 Assert.assertTrue("" + actual + " is not NaN.",
65 Double.isNaN(actual));
66 } else {
67 Assert.assertEquals(msg, expected, actual, delta);
68 }
69 }
70
71
72
73
74
75 public static void assertSame(double expected, double actual) {
76 Assert.assertEquals(expected, actual, 0);
77 }
78
79
80
81
82
83 public static void assertSame(Complex expected, Complex actual) {
84 assertSame(expected.getReal(), actual.getReal());
85 assertSame(expected.getImaginary(), actual.getImaginary());
86 }
87
88
89
90
91
92 public static void assertEquals(Complex expected, Complex actual, double delta) {
93 Assert.assertEquals(expected.getReal(), actual.getReal(), delta);
94 Assert.assertEquals(expected.getImaginary(), actual.getImaginary(), delta);
95 }
96
97
98
99
100 public static void assertEquals(double expected[], double observed[], double tolerance) {
101 assertEquals("Array comparison failure", expected, observed, tolerance);
102 }
103
104
105
106
107
108
109
110
111 public static Object serializeAndRecover(Object o) {
112 try {
113
114 ByteArrayOutputStream bos = new ByteArrayOutputStream();
115 ObjectOutputStream so = new ObjectOutputStream(bos);
116 so.writeObject(o);
117
118
119 ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
120 ObjectInputStream si = new ObjectInputStream(bis);
121 return si.readObject();
122 } catch (Exception e) {
123 throw new RuntimeException(e);
124 }
125 }
126
127
128
129
130
131
132
133 public static void checkSerializedEquality(Object object) {
134 Object object2 = serializeAndRecover(object);
135 Assert.assertEquals("Equals check", object, object2);
136 Assert.assertEquals("HashCode check", object.hashCode(), object2.hashCode());
137 }
138
139
140
141
142
143
144
145
146
147
148 public static void assertRelativelyEquals(double expected, double actual,
149 double relativeError) {
150 assertRelativelyEquals(null, expected, actual, relativeError);
151 }
152
153
154
155
156
157
158
159
160
161
162
163 public static void assertRelativelyEquals(String msg, double expected,
164 double actual, double relativeError) {
165 if (Double.isNaN(expected)) {
166 Assert.assertTrue(msg, Double.isNaN(actual));
167 } else if (Double.isNaN(actual)) {
168 Assert.assertTrue(msg, Double.isNaN(expected));
169 } else if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
170 Assert.assertEquals(expected, actual, relativeError);
171 } else if (expected == 0.0) {
172 Assert.assertEquals(msg, actual, expected, relativeError);
173 } else {
174 double absError = JdkMath.abs(expected) * relativeError;
175 Assert.assertEquals(msg, expected, actual, absError);
176 }
177 }
178
179
180
181
182
183
184
185
186
187 public static void assertContains(String msg, Complex[] values,
188 Complex z, double epsilon) {
189 for (Complex value : values) {
190 if (Precision.equals(value.getReal(), z.getReal(), epsilon) &&
191 Precision.equals(value.getImaginary(), z.getImaginary(), epsilon)) {
192 return;
193 }
194 }
195 Assert.fail(msg + " Unable to find " + (new ComplexFormat()).format(z));
196 }
197
198
199
200
201
202
203
204
205 public static void assertContains(Complex[] values,
206 Complex z, double epsilon) {
207 assertContains(null, values, z, epsilon);
208 }
209
210
211
212
213
214
215
216
217
218 public static void assertContains(String msg, double[] values,
219 double x, double epsilon) {
220 for (double value : values) {
221 if (Precision.equals(value, x, epsilon)) {
222 return;
223 }
224 }
225 Assert.fail(msg + " Unable to find " + x);
226 }
227
228
229
230
231
232
233
234
235 public static void assertContains(double[] values, double x,
236 double epsilon) {
237 assertContains(null, values, x, epsilon);
238 }
239
240
241
242
243
244
245
246
247
248
249
250
251 public static void assertEquals(final String message,
252 final double[] expected, final RealVector actual, final double delta) {
253 final String msgAndSep = message.isEmpty() ? "" : message + ", ";
254 Assert.assertEquals(msgAndSep + "dimension", expected.length,
255 actual.getDimension());
256 for (int i = 0; i < expected.length; i++) {
257 Assert.assertEquals(msgAndSep + "entry #" + i, expected[i],
258 actual.getEntry(i), delta);
259 }
260 }
261
262
263
264
265
266
267
268
269
270
271
272
273 public static void assertEquals(final String message,
274 final RealVector expected, final RealVector actual, final double delta) {
275 final String msgAndSep = message.isEmpty() ? "" : message + ", ";
276 Assert.assertEquals(msgAndSep + "dimension", expected.getDimension(),
277 actual.getDimension());
278 final int dim = expected.getDimension();
279 for (int i = 0; i < dim; i++) {
280 Assert.assertEquals(msgAndSep + "entry #" + i,
281 expected.getEntry(i), actual.getEntry(i), delta);
282 }
283 }
284
285
286 public static void assertEquals(String msg, RealMatrix expected, RealMatrix observed, double tolerance) {
287
288 Assert.assertNotNull(msg + "\nObserved should not be null",observed);
289
290 if (expected.getColumnDimension() != observed.getColumnDimension() ||
291 expected.getRowDimension() != observed.getRowDimension()) {
292 StringBuilder messageBuffer = new StringBuilder(msg);
293 messageBuffer.append("\nObserved has incorrect dimensions.");
294 messageBuffer.append("\nobserved is " + observed.getRowDimension() +
295 " x " + observed.getColumnDimension());
296 messageBuffer.append("\nexpected " + expected.getRowDimension() +
297 " x " + expected.getColumnDimension());
298 Assert.fail(messageBuffer.toString());
299 }
300
301 RealMatrix delta = expected.subtract(observed);
302 if (delta.getNorm() >= tolerance) {
303 StringBuilder messageBuffer = new StringBuilder(msg);
304 messageBuffer.append("\nExpected: " + expected);
305 messageBuffer.append("\nObserved: " + observed);
306 messageBuffer.append("\nexpected - observed: " + delta);
307 Assert.fail(messageBuffer.toString());
308 }
309 }
310
311
312 public static void assertEquals(FieldMatrix<? extends FieldElement<?>> expected,
313 FieldMatrix<? extends FieldElement<?>> observed) {
314
315 Assert.assertNotNull("Observed should not be null",observed);
316
317 if (expected.getColumnDimension() != observed.getColumnDimension() ||
318 expected.getRowDimension() != observed.getRowDimension()) {
319 StringBuilder messageBuffer = new StringBuilder();
320 messageBuffer.append("Observed has incorrect dimensions.");
321 messageBuffer.append("\nobserved is " + observed.getRowDimension() +
322 " x " + observed.getColumnDimension());
323 messageBuffer.append("\nexpected " + expected.getRowDimension() +
324 " x " + expected.getColumnDimension());
325 Assert.fail(messageBuffer.toString());
326 }
327
328 for (int i = 0; i < expected.getRowDimension(); ++i) {
329 for (int j = 0; j < expected.getColumnDimension(); ++j) {
330 FieldElement<?> eij = expected.getEntry(i, j);
331 FieldElement<?> oij = observed.getEntry(i, j);
332 Assert.assertEquals(eij, oij);
333 }
334 }
335 }
336
337
338 public static void assertEquals(String msg, double[] expected, double[] observed, double tolerance) {
339 StringBuilder out = new StringBuilder(msg);
340 if (expected.length != observed.length) {
341 out.append("\n Arrays not same length. \n");
342 out.append("expected has length ");
343 out.append(expected.length);
344 out.append(" observed length = ");
345 out.append(observed.length);
346 Assert.fail(out.toString());
347 }
348 boolean failure = false;
349 for (int i=0; i < expected.length; i++) {
350 if (!Precision.equalsIncludingNaN(expected[i], observed[i], tolerance)) {
351 failure = true;
352 out.append("\n Elements at index ");
353 out.append(i);
354 out.append(" differ. ");
355 out.append(" expected = ");
356 out.append(expected[i]);
357 out.append(" observed = ");
358 out.append(observed[i]);
359 }
360 }
361 if (failure) {
362 Assert.fail(out.toString());
363 }
364 }
365
366
367 public static void assertEquals(String msg, float[] expected, float[] observed, float tolerance) {
368 StringBuilder out = new StringBuilder(msg);
369 if (expected.length != observed.length) {
370 out.append("\n Arrays not same length. \n");
371 out.append("expected has length ");
372 out.append(expected.length);
373 out.append(" observed length = ");
374 out.append(observed.length);
375 Assert.fail(out.toString());
376 }
377 boolean failure = false;
378 for (int i=0; i < expected.length; i++) {
379 if (!Precision.equalsIncludingNaN(expected[i], observed[i], tolerance)) {
380 failure = true;
381 out.append("\n Elements at index ");
382 out.append(i);
383 out.append(" differ. ");
384 out.append(" expected = ");
385 out.append(expected[i]);
386 out.append(" observed = ");
387 out.append(observed[i]);
388 }
389 }
390 if (failure) {
391 Assert.fail(out.toString());
392 }
393 }
394
395
396 public static void assertEquals(String msg, Complex[] expected, Complex[] observed, double tolerance) {
397 StringBuilder out = new StringBuilder(msg);
398 if (expected.length != observed.length) {
399 out.append("\n Arrays not same length. \n");
400 out.append("expected has length ");
401 out.append(expected.length);
402 out.append(" observed length = ");
403 out.append(observed.length);
404 Assert.fail(out.toString());
405 }
406 boolean failure = false;
407 for (int i=0; i < expected.length; i++) {
408 if (!Precision.equalsIncludingNaN(expected[i].getReal(), observed[i].getReal(), tolerance)) {
409 failure = true;
410 out.append("\n Real elements at index ");
411 out.append(i);
412 out.append(" differ. ");
413 out.append(" expected = ");
414 out.append(expected[i].getReal());
415 out.append(" observed = ");
416 out.append(observed[i].getReal());
417 }
418 if (!Precision.equalsIncludingNaN(expected[i].getImaginary(), observed[i].getImaginary(), tolerance)) {
419 failure = true;
420 out.append("\n Imaginary elements at index ");
421 out.append(i);
422 out.append(" differ. ");
423 out.append(" expected = ");
424 out.append(expected[i].getImaginary());
425 out.append(" observed = ");
426 out.append(observed[i].getImaginary());
427 }
428 }
429 if (failure) {
430 Assert.fail(out.toString());
431 }
432 }
433
434
435 public static <T extends FieldElement<T>> void assertEquals(T[] m, T[] n) {
436 if (m.length != n.length) {
437 Assert.fail("vectors not same length");
438 }
439 for (int i = 0; i < m.length; i++) {
440 Assert.assertEquals(m[i],n[i]);
441 }
442 }
443
444
445
446
447
448
449
450
451 public static double sumSquareDev(double[] values, double target) {
452 double sumsq = 0d;
453 for (int i = 0; i < values.length; i++) {
454 final double dev = values[i] - target;
455 sumsq += dev * dev;
456 }
457 return sumsq;
458 }
459
460
461
462
463
464
465
466
467
468
469 public static void assertChiSquareAccept(String[] valueLabels, double[] expected, long[] observed, double alpha) {
470 ChiSquareTest chiSquareTest = new ChiSquareTest();
471
472
473 if (chiSquareTest.chiSquareTest(expected, observed, alpha)) {
474 StringBuilder msgBuffer = new StringBuilder();
475 DecimalFormat df = new DecimalFormat("#.##");
476 msgBuffer.append("Chisquare test failed");
477 msgBuffer.append(" p-value = ");
478 msgBuffer.append(chiSquareTest.chiSquareTest(expected, observed));
479 msgBuffer.append(" chisquare statistic = ");
480 msgBuffer.append(chiSquareTest.chiSquare(expected, observed));
481 msgBuffer.append(". \n");
482 msgBuffer.append("value\texpected\tobserved\n");
483 for (int i = 0; i < expected.length; i++) {
484 msgBuffer.append(valueLabels[i]);
485 msgBuffer.append("\t");
486 msgBuffer.append(df.format(expected[i]));
487 msgBuffer.append("\t\t");
488 msgBuffer.append(observed[i]);
489 msgBuffer.append("\n");
490 }
491 msgBuffer.append("This test can fail randomly due to sampling error with probability ");
492 msgBuffer.append(alpha);
493 msgBuffer.append(".");
494 Assert.fail(msgBuffer.toString());
495 }
496 }
497
498
499
500
501
502
503
504
505
506
507 public static void assertChiSquareAccept(int[] values, double[] expected, long[] observed, double alpha) {
508 String[] labels = new String[values.length];
509 for (int i = 0; i < values.length; i++) {
510 labels[i] = Integer.toString(values[i]);
511 }
512 assertChiSquareAccept(labels, expected, observed, alpha);
513 }
514
515
516
517
518
519
520
521
522
523 public static void assertChiSquareAccept(double[] expected, long[] observed, double alpha) {
524 String[] labels = new String[expected.length];
525 for (int i = 0; i < labels.length; i++) {
526 labels[i] = Integer.toString(i + 1);
527 }
528 assertChiSquareAccept(labels, expected, observed, alpha);
529 }
530
531
532
533
534
535 public static double[] getDistributionQuartiles(ContinuousDistribution distribution) {
536 double[] quantiles = new double[3];
537 quantiles[0] = distribution.inverseCumulativeProbability(0.25d);
538 quantiles[1] = distribution.inverseCumulativeProbability(0.5d);
539 quantiles[2] = distribution.inverseCumulativeProbability(0.75d);
540 return quantiles;
541 }
542
543
544
545
546
547 public static void updateCounts(double value, long[] counts, double[] quartiles) {
548 if (value < quartiles[0]) {
549 counts[0]++;
550 } else if (value > quartiles[2]) {
551 counts[3]++;
552 } else if (value > quartiles[1]) {
553 counts[2]++;
554 } else {
555 counts[1]++;
556 }
557 }
558
559
560
561
562
563
564
565 public static int eliminateZeroMassPoints(int[] densityPoints, double[] densityValues) {
566 int positiveMassCount = 0;
567 for (int i = 0; i < densityValues.length; i++) {
568 if (densityValues[i] > 0) {
569 positiveMassCount++;
570 }
571 }
572 if (positiveMassCount < densityValues.length) {
573 int[] newPoints = new int[positiveMassCount];
574 double[] newValues = new double[positiveMassCount];
575 int j = 0;
576 for (int i = 0; i < densityValues.length; i++) {
577 if (densityValues[i] > 0) {
578 newPoints[j] = densityPoints[i];
579 newValues[j] = densityValues[i];
580 j++;
581 }
582 }
583 System.arraycopy(newPoints,0,densityPoints,0,positiveMassCount);
584 System.arraycopy(newValues,0,densityValues,0,positiveMassCount);
585 }
586 return positiveMassCount;
587 }
588 }