1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla.forward;
18  
19  import org.apache.commons.math3.analysis.UnivariateFunction;
20  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
21  import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
22  import org.apache.commons.math3.util.FastMath;
23  import org.apache.commons.nabla.AbstractMathTest;
24  import org.junit.Assert;
25  import org.junit.Test;
26  
/**
 * Tests for {@link ForwardModeDifferentiator}.
 * <p>
 * Most tests pass an anonymous {@code ReferenceFunction} to {@code checkReference},
 * inherited from {@link AbstractMathTest}. Each function provides both
 * {@code value(t)} and a hand-derived {@code firstDerivative(t)} against which the
 * differentiator's output is compared.
 * NOTE(review): the trailing {@code checkReference} arguments look like
 * (lower bound, upper bound, sample count, tolerance) — confirm against
 * {@code AbstractMathTest}.
 * </p>
 * <p>
 * NOTE(review): the {@code value} methods are what the differentiator is asked to
 * differentiate, and ForwardModeDifferentiator presumably works on their compiled
 * form. Keep their exact source shape when editing. For example, the Ldc/Dcons
 * loop tests differ only in the initial constant, which javac compiles to
 * different load instructions.
 * </p>
 */
public class ForwardModeDifferentiatorTest extends AbstractMathTest {

    /** A single library call: d/dt cos(t) = -sin(t). */
    @Test
    public void testSingleCall() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return FastMath.cos(t); }
            public double firstDerivative(double t) { return -FastMath.sin(t); }
        }, 0, 2 * FastMath.PI, 20, 2.0e-16);
    }

    /** Chain rule through nested calls: d/dt exp(sin(t)) = cos(t) * exp(sin(t)). */
    @Test
    public void testEmbeddedCalls() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return FastMath.exp(FastMath.sin(t)); }
            public double firstDerivative(double t) { return FastMath.cos(t) * FastMath.exp(FastMath.sin(t)); }
        }, 0.1, 10, 20, 2.0e-16);
    }

    /** A result that does not depend on the parameter at all: the derivative must be 0. */
    @Test
    public void testParameterIndependent() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return 12; }
            public double firstDerivative(double t) { return 0; }
        }, 0.1, 5, 20, 1.0e-20);
    }

    /**
     * Division of a constant by the parameter: d/dt (1/t) = -1/t^2.
     * NOTE(review): the range starts at 0.1, presumably to stay clear of the pole at t = 0.
     */
    @Test
    public void testSimpleExpression() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return 1.0 / t; }
            public double firstDerivative(double t) { return -1 / (t * t); }
        }, 0.1, 5, 20, 2.0e-14);
    }

    /** Product of the parameter with itself: d/dt t^2 = 2t. */
    @Test
    public void testMul() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return t * t; }
            public double firstDerivative(double t) { return 2 * t; }
        }, 0.1, 5, 20, 1.0e-20);
    }

    /**
     * Horner-form quartic 4t^4 + 2t^3 + t^2 - 2t + 5, whose derivative
     * 16t^3 + 6t^2 + 2t - 2 is also written in Horner form.
     */
    @Test
    public void testPolynomialExpression() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return (((4 * t + 2) * t + 1) * t - 2) * t + 5; }
            public double firstDerivative(double t) { return ((16 * t + 6) * t + 2) * t - 2; }
        }, 0.1, 5, 20, 3.0e-13);
    }

    /**
     * Narrowing cast inside an expression: t - (int) t.
     * javac compiles the cast to d2i, and the widening back for the subtraction to i2d.
     * The truncated term is piecewise constant, so the reference derivative is taken as 1.
     */
    @Test
    public void testNarrowing() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { return t - (int) t; }
            public double firstDerivative(double t) { return 1; }
        }, 0.1, 5, 20, 1.0e-20);
    }

    /** Intermediate result held in a local variable: (3t)^2 = 9t^2, derivative 18t. */
    @Test
    public void testLocalVariables() {
        checkReference(new ReferenceFunction() {
            public double value(double t) { double threeT = 3 * t; return threeT * threeT; }
            public double firstDerivative(double t) { return 18 * t; }
        }, -5, 5, 20, 2.0e-14);
    }

    /**
     * Loop accumulating a product that starts from 2.0: value 2t^3, derivative 6t^2.
     * There is no dconst_2 instruction, so javac loads 2.0 with ldc2_w (hence "Ldc").
     */
    @Test
    public void testLoopLdc() {
        checkReference(new ReferenceFunction() {
            public double value(double t) {
                double result = 2.0;
                for (int i = 0; i < 3; ++i) {
                    result *= t;
                }
                return result;
            }
            public double firstDerivative(double t) { return 6 * t * t; }
        }, -5, 5, 20, 2.0e-14);
    }

    /**
     * Same loop as {@link #testLoopLdc()} but starting from 1.0: value t^3, derivative 3t^2.
     * javac loads 1.0 with dconst_1 (hence "Dcons").
     */
    @Test
    public void testLoopDcons() {
        checkReference(new ReferenceFunction() {
            public double value(double t) {
                double result = 1.0;
                for (int i = 0; i < 3; ++i) {
                    result *= t;
                }
                return result;
            }
            public double firstDerivative(double t) { return 3 * t * t; }
        }, -5, 5, 20, 8.0e-15);
    }

    /**
     * Differentiates f(y) = x*y + y*y, where x is a mutable field of the original
     * function instance, and evaluates it at y = 1 where f(1) = x + 1 and f'(1) = x + 2.
     * After {@code setX(2)} on the original instance, the already-built derivative is
     * expected to reflect the new x without being differentiated again.
     */
    @Test
    public void testPartialDerivatives() throws Exception {
        PartialFunction function = new PartialFunction(1);

        final UnivariateDifferentiableFunction derivative = new
                ForwardModeDifferentiator().differentiate(function);
        // one free parameter, derivation order 1, t is variable #0, evaluated at t = 1.0
        DerivativeStructure t = new DerivativeStructure(1, 1, 0, 1.0);
        // x = 1: f'(1) = 1 + 2 = 3, f(1) = 1 + 1 = 2
        Assert.assertEquals(3, derivative.value(t).getPartialDerivative(1), 1.0e-20);
        Assert.assertEquals(2, derivative.value(t).getValue(), 1.0e-20);
        // mutate the original instance; the existing derivative must see x = 2
        function.setX(2);
        Assert.assertEquals(4, derivative.value(t).getPartialDerivative(1), 1.0e-20);
        Assert.assertEquals(3, derivative.value(t).getValue(), 1.0e-20);
    }

    /**
     * Function f(y) = x*y + y*y with a mutable coefficient x, used by
     * {@link #testPartialDerivatives()}.
     * NOTE(review): this is a non-static inner class, so each instance carries a
     * reference to the enclosing test instance. It is left as-is because its
     * bytecode is the differentiator's input.
     */
    public class PartialFunction implements UnivariateFunction {
        /** Coefficient of the linear term; may be changed after differentiation. */
        private double x;
        /** @param x initial coefficient of the linear term */
        public PartialFunction(double x) {
            this.x = x;
        }
        /** @param x new coefficient of the linear term */
        public void setX(double x) {
            this.x = x;
        }
        /** @return current coefficient of the linear term (not called within this file) */
        public double getX() {
            return x;
        }
        /** @return x*y + y*y using the current value of x */
        public double value(double y) {
            return x * y + y * y;
        }
    }

}