1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla.automatic;
18
19 import org.apache.commons.nabla.ReferenceFunction;
20 import org.apache.commons.nabla.core.DifferentialPair;
21 import org.apache.commons.nabla.core.DifferentiationException;
22 import org.apache.commons.nabla.core.UnivariateDerivative;
23
24 import junit.framework.TestCase;
25
26 public abstract class AbstractMathTest extends TestCase {
27
28 protected void checkReference(ReferenceFunction reference,
29 double t0, double t1, int n,
30 double threshold) {
31 try {
32 AutomaticDifferentiator differentiator = new AutomaticDifferentiator();
33 differentiator.addMathImplementation(MathExtensions.class);
34 UnivariateDerivative derivative = differentiator.differentiate(reference);
35 for (int i = 0; i < n; ++i) {
36 double t = ((n - 1 - i) * t0 + i * t1) / (n - 1);
37 DifferentialPair dpT = DifferentialPair.newVariable(t);
38 assertEquals(reference.fPrime(t), derivative.f(dpT).getFirstDerivative(), threshold);
39 }
40 } catch (DifferentiationException de) {
41 fail(de.getMessage());
42 }
43 }
44
45 public static class MathExtensions {
46 public static double acosh(double a) {
47 return Math.log(a + Math.sqrt(a - 1) * Math.sqrt(a + 1));
48 }
49 public static double asinh(double a) {
50 return Math.log(a + Math.sqrt(a * a + 1));
51 }
52 public static double atanh(double a) {
53 return (Math.log1p(a) - Math.log1p(-a)) / 2;
54 }
55 public static double sqrt(double a) {
56 return Math.sqrt(a);
57 }
58 }
59
60 }