1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla;
18  
19  import org.apache.commons.math3.analysis.UnivariateFunction;
20  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
21  import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
22  import org.apache.commons.math3.util.FastMath;
23  import org.apache.commons.nabla.forward.ForwardModeDifferentiator;
24  import org.junit.Assert;
25  
26  
27  public abstract class AbstractMathTest {
28  
29      public interface ReferenceFunction extends UnivariateFunction {
30          double firstDerivative(double t);
31      }
32  
33      protected void checkReference(ReferenceFunction reference,
34                                    double t0, double t1, int n,
35                                    double threshold) {
36          try {
37              ForwardModeDifferentiator differentiator = new ForwardModeDifferentiator();
38              differentiator.addMathImplementation(MathExtensions.class);
39              UnivariateDifferentiableFunction derivative = differentiator.differentiate(reference);
40              for (int i = 0; i < n; ++i) {
41                  double t = ((n - 1 - i) * t0 + i * t1) / (n - 1);
42                  DerivativeStructure dpT = new DerivativeStructure(1, 1, 0, t);
43                  Assert.assertEquals("error = " + ((reference.firstDerivative(t) - derivative.value(dpT).getPartialDerivative(1))),
44                                      reference.firstDerivative(t), derivative.value(dpT).getPartialDerivative(1),
45                                      threshold);
46                  Assert.assertEquals("error = " + ((reference.value(t) - derivative.value(dpT).getValue())),
47                                      reference.value(t), derivative.value(dpT).getValue(),
48                                      threshold);
49              }
50          } catch (DifferentiationException de) {
51              de.printStackTrace(System.err);
52              Assert.fail(de.getLocalizedMessage());
53          }
54      }
55  
56      public static class MathExtensions {
57          public static double acosh(double a) {
58              return FastMath.log(a + FastMath.sqrt(a - 1) * FastMath.sqrt(a + 1));
59          }
60          public static double asinh(double a) {
61              return FastMath.log(a + FastMath.sqrt(a * a + 1));
62          }
63          public static double atanh(double a) {
64              return (FastMath.log1p(a) - FastMath.log1p(-a)) / 2;
65          }
66          public static double sqrt(double a) {
67              return FastMath.sqrt(a);
68          }
69      }
70  
71  }