1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla.automatic;
18  
import org.apache.commons.nabla.ReferenceFunction;
import org.apache.commons.nabla.core.DifferentialPair;
import org.apache.commons.nabla.core.DifferentiationException;
import org.apache.commons.nabla.core.UnivariateDerivative;

import junit.framework.AssertionFailedError;
import junit.framework.TestCase;
25  
26  public abstract class AbstractMathTest extends TestCase {
27  
28      protected void checkReference(ReferenceFunction reference,
29                                    double t0, double t1, int n,
30                                    double threshold) {
31          try {
32              AutomaticDifferentiator differentiator = new AutomaticDifferentiator();
33              differentiator.addMathImplementation(MathExtensions.class);
34              UnivariateDerivative derivative = differentiator.differentiate(reference);
35              for (int i = 0; i < n; ++i) {
36                  double t = ((n - 1 - i) * t0 + i * t1) / (n - 1);
37                  DifferentialPair dpT = DifferentialPair.newVariable(t);
38                  assertEquals(reference.fPrime(t), derivative.f(dpT).getFirstDerivative(), threshold);
39              }
40          } catch (DifferentiationException de) {
41              fail(de.getMessage());
42          }
43      }
44  
45      public static class MathExtensions {
46          public static double acosh(double a) {
47              return Math.log(a + Math.sqrt(a - 1) * Math.sqrt(a + 1));
48          }
49          public static double asinh(double a) {
50              return Math.log(a + Math.sqrt(a * a + 1));
51          }
52          public static double atanh(double a) {
53              return (Math.log1p(a) - Math.log1p(-a)) / 2;
54          }
55          public static double sqrt(double a) {
56              return Math.sqrt(a);
57          }
58      }
59  
60  }