1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla;
18
19 import org.apache.commons.math3.analysis.UnivariateFunction;
20 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
21 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
22 import org.apache.commons.math3.util.FastMath;
23 import org.apache.commons.nabla.forward.ForwardModeDifferentiator;
24 import org.junit.Assert;
25
26
27 public abstract class AbstractMathTest {
28
29 public interface ReferenceFunction extends UnivariateFunction {
30 double firstDerivative(double t);
31 }
32
33 protected void checkReference(ReferenceFunction reference,
34 double t0, double t1, int n,
35 double threshold) {
36 try {
37 ForwardModeDifferentiator differentiator = new ForwardModeDifferentiator();
38 differentiator.addMathImplementation(MathExtensions.class);
39 UnivariateDifferentiableFunction derivative = differentiator.differentiate(reference);
40 for (int i = 0; i < n; ++i) {
41 double t = ((n - 1 - i) * t0 + i * t1) / (n - 1);
42 DerivativeStructure dpT = new DerivativeStructure(1, 1, 0, t);
43 Assert.assertEquals("error = " + ((reference.firstDerivative(t) - derivative.value(dpT).getPartialDerivative(1))),
44 reference.firstDerivative(t), derivative.value(dpT).getPartialDerivative(1),
45 threshold);
46 Assert.assertEquals("error = " + ((reference.value(t) - derivative.value(dpT).getValue())),
47 reference.value(t), derivative.value(dpT).getValue(),
48 threshold);
49 }
50 } catch (DifferentiationException de) {
51 de.printStackTrace(System.err);
52 Assert.fail(de.getLocalizedMessage());
53 }
54 }
55
56 public static class MathExtensions {
57 public static double acosh(double a) {
58 return FastMath.log(a + FastMath.sqrt(a - 1) * FastMath.sqrt(a + 1));
59 }
60 public static double asinh(double a) {
61 return FastMath.log(a + FastMath.sqrt(a * a + 1));
62 }
63 public static double atanh(double a) {
64 return (FastMath.log1p(a) - FastMath.log1p(-a)) / 2;
65 }
66 public static double sqrt(double a) {
67 return FastMath.sqrt(a);
68 }
69 }
70
71 }