1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla.forward;
18
19 import org.apache.commons.math3.analysis.UnivariateFunction;
20 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
21 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
22 import org.apache.commons.math3.util.FastMath;
23 import org.apache.commons.nabla.AbstractMathTest;
24 import org.junit.Assert;
25 import org.junit.Test;
26
27 public class ForwardModeDifferentiatorTest extends AbstractMathTest {
28
29 @Test
30 public void testSingleCall() {
31 checkReference(new ReferenceFunction() {
32 public double value(double t) { return FastMath.cos(t); }
33 public double firstDerivative(double t) { return -FastMath.sin(t); }
34 }, 0, 2 * FastMath.PI, 20, 2.0e-16);
35 }
36
37 @Test
38 public void testEmbeddedCalls() {
39 checkReference(new ReferenceFunction() {
40 public double value(double t) { return FastMath.exp(FastMath.sin(t)); }
41 public double firstDerivative(double t) { return FastMath.cos(t) * FastMath.exp(FastMath.sin(t)); }
42 }, 0.1, 10, 20, 2.0e-16);
43 }
44
45 @Test
46 public void testParameterIndependent() {
47 checkReference(new ReferenceFunction() {
48 public double value(double t) { return 12; }
49 public double firstDerivative(double t) { return 0; }
50 }, 0.1, 5, 20, 1.0e-20);
51 }
52
53 @Test
54 public void testSimpleExpression() {
55 checkReference(new ReferenceFunction() {
56 public double value(double t) { return 1.0 / t; }
57 public double firstDerivative(double t) { return -1 / (t * t); }
58 }, 0.1, 5, 20, 2.0e-14);
59 }
60
61 @Test
62 public void testMul() {
63 checkReference(new ReferenceFunction() {
64 public double value(double t) { return t * t; }
65 public double firstDerivative(double t) { return 2 * t; }
66 }, 0.1, 5, 20, 1.0e-20);
67 }
68
69 @Test
70 public void testPolynomialExpression() {
71 checkReference(new ReferenceFunction() {
72 public double value(double t) { return (((4 * t + 2) * t + 1) * t - 2) * t + 5; }
73 public double firstDerivative(double t) { return ((16 * t + 6) * t + 2) * t - 2; }
74 }, 0.1, 5, 20, 3.0e-13);
75 }
76
77 @Test
78 public void testNarrowing() {
79 checkReference(new ReferenceFunction() {
80 public double value(double t) { return t - (int) t; }
81 public double firstDerivative(double t) { return 1; }
82 }, 0.1, 5, 20, 1.0e-20);
83 }
84
85 @Test
86 public void testLocalVariables() {
87 checkReference(new ReferenceFunction() {
88 public double value(double t) { double threeT = 3 * t; return threeT * threeT; }
89 public double firstDerivative(double t) { return 18 * t; }
90 }, -5, 5, 20, 2.0e-14);
91 }
92
93 @Test
94 public void testLoopLdc() {
95 checkReference(new ReferenceFunction() {
96 public double value(double t) {
97 double result = 2.0;
98 for (int i = 0; i < 3; ++i) {
99 result *= t;
100 }
101 return result;
102 }
103 public double firstDerivative(double t) { return 6 * t * t; }
104 }, -5, 5, 20, 2.0e-14);
105 }
106
107 @Test
108 public void testLoopDcons() {
109 checkReference(new ReferenceFunction() {
110 public double value(double t) {
111 double result = 1.0;
112 for (int i = 0; i < 3; ++i) {
113 result *= t;
114 }
115 return result;
116 }
117 public double firstDerivative(double t) { return 3 * t * t; }
118 }, -5, 5, 20, 8.0e-15);
119 }
120
121 @Test
122 public void testPartialDerivatives() throws Exception {
123 PartialFunction function = new PartialFunction(1);
124
125 final UnivariateDifferentiableFunction derivative = new
126 ForwardModeDifferentiator().differentiate(function);
127 DerivativeStructure t = new DerivativeStructure(1, 1, 0, 1.0);
128 Assert.assertEquals(3, derivative.value(t).getPartialDerivative(1), 1.0e-20);
129 Assert.assertEquals(2, derivative.value(t).getValue(), 1.0e-20);
130 function.setX(2);
131 Assert.assertEquals(4, derivative.value(t).getPartialDerivative(1), 1.0e-20);
132 Assert.assertEquals(3, derivative.value(t).getValue(), 1.0e-20);
133 }
134
135 public class PartialFunction implements UnivariateFunction {
136 private double x;
137 public PartialFunction(double x) {
138 this.x = x;
139 }
140 public void setX(double x) {
141 this.x = x;
142 }
143 public double getX() {
144 return x;
145 }
146 public double value(double y) {
147 return x * y + y * y;
148 }
149 }
150
151 }