1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla;
18  
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;

import org.apache.commons.nabla.core.DifferentialPair;
import org.apache.commons.nabla.core.UnivariateDifferentiable;
import org.apache.commons.nabla.numerical.EightPointsScheme;
import org.junit.Assert;
26  
27  
28  public abstract class AbstractStaticFunctionsTest {
29  
30      /** Get the nabla class to test (either NablaMath or NablaStrictMath).
31       * @return nabla class
32       */
33      protected abstract Class<?> getNablaClass();
34  
35      /** Get the java class to test (either Math or StrictMath).
36       * @return java class
37       */
38      protected abstract Class<?> getJavaClass();
39  
40      protected void defaultMonadicTest(String name) {
41          checkMonadicFunction(getNablaClass(), getJavaClass(), name,
42                               0.1, 0.3, 10, 1.0e-20, 2.0e-12);
43      }
44  
45      protected void defaultDiadicTest(String name) {
46          checkDiadicFunction(getNablaClass(), getJavaClass(), name,
47                              1.0, 4.0, 4, 2.0, 7.0, 6, 1.0e-15, 1.0e-11);
48      }
49  
50      /** Check a monadic static function in a range.
51       * @param nablaClass class implementing the function with differentials
52       * @param javaClass class implementing the function without differentials
53       * @param name name of the function
54       * @param lower lower bound of the check interval
55       * @param upper upper bound of the check interval
56       * @param n number of check points
57       * @param valueTolerance tolerance on the value
58       * @param differentialTolerance tolerance on the differentials
59       */
60      protected void checkMonadicFunction(Class<?> nablaClass, Class<?> javaClass,
61                                          String name,
62                                          double lower, double upper, int n,
63                                          double valueTolerance,
64                                          double differentialTolerance) {
65  
66          // build the sequence of arguments sets
67          double[][] argsSetsSequence = new double[n][];
68          for (int i = 0; i < n; ++i) {
69              double arg = (((n - 1) - i) * lower + i * upper) / (n - 1);
70              argsSetsSequence[i] = new double[] { arg };
71          }
72  
73          // perform the check
74          checkFunction(nablaClass, javaClass, new Class<?>[] { Double.TYPE },
75                        name, argsSetsSequence, valueTolerance, differentialTolerance);
76  
77      }
78  
79      /** Check a diadic static function in a range.
80       * @param nablaClass class implementing the function with differentials
81       * @param javaClass class implementing the function without differentials
82       * @param name name of the function
83       * @param lower1 lower bound of the check interval for the first argument
84       * @param upper1 upper bound of the check interval for the first argument
85       * @param n1 number of check points for the first argument
86       * @param lower2 lower bound of the check interval for the second argument
87       * @param upper2 upper bound of the check interval for the second argument
88       * @param n2 number of check points for the second argument
89       * @param valueTolerance tolerance on the value
90       * @param differentialTolerance tolerance on the differentials
91       */
92      protected void checkDiadicFunction(Class<?> nablaClass, Class<?> javaClass,
93                                         String name,
94                                         double lower1, double upper1, int n1,
95                                         double lower2, double upper2, int n2,
96                                         double valueTolerance,
97                                         double differentialTolerance) {
98  
99          // build the sequence of arguments sets
100         double[][] argsSetsSequence = new double[n1 * n2][];
101         int k = 0;
102         for (int i = 0; i < n1; ++i) {
103             double arg1 = (((n1 - 1) - i) * lower1 + i * upper1) / (n1 - 1);
104             for (int j = 0; j < n2; ++j) {
105                 double arg2 = (((n2 - 1) - j) * lower2 + j * upper2) / (n2 - 1);
106                 argsSetsSequence[k++] = new double[] { arg1, arg2 };
107             }
108         }
109 
110         // perform the check
111         checkFunction(nablaClass, javaClass, new Class<?>[] { Double.TYPE, Double.TYPE },
112                       name, argsSetsSequence, valueTolerance, differentialTolerance);
113 
114     }
115 
116     /** Check a static function for a sequence of arguments sets.
117      * @param nablaClass class implementing the function with differentials
118      * @param javaClass class implementing the function without differentials
119      * @param javaTypes types of the parameters of the java reference method
120      * @param name name of the function
121      * @param argsSetsSequence sequence of arguments sets
122      * @param valueTolerance tolerance on the value
123      * @param differentialTolerance tolerance on the differential
124      */
125     protected void checkFunction(Class<?> nablaClass,
126                                  Class<?> javaClass, Class<?>[] javaTypes,
127                                  String name, double[][] argsSetsSequence,
128                                  double valueTolerance,
129                                  double differentialTolerance) {
130         try {
131 
132             // get the reference java method
133             Method javaMethod = javaClass.getMethod(name, javaTypes);
134 
135             // get the nabla methods we can test
136             Method[] nablaMethods = nablaClass.getMethods();
137             for (int i = 0; i < nablaMethods.length; ++i) {
138                 if (nablaMethods[i].getName().equals(name) &&
139                         (nablaMethods[i].getParameterTypes().length == argsSetsSequence[0].length)) {
140 
141                     // test this method
142                     for (int j = 0; j < argsSetsSequence.length; ++j) {
143                         checkValue(nablaMethods[i], javaMethod, argsSetsSequence[j], valueTolerance);
144                         checkDifferential(nablaMethods[i], javaMethod, argsSetsSequence[j], differentialTolerance);
145                     }
146 
147                 }
148             }
149 
150         } catch (NoSuchMethodException nsme) {
151             Assert.fail(nsme.getMessage());
152         }
153 
154     }
155 
156     /** Check the value of the function.
157      * @param nablaMethod method with differential
158      * @param javaMethod method without differential
159      * @param argsSet arguments set
160      * @param valueTolerance tolerance on the value
161      */
162     private void checkValue(Method nablaMethod, Method javaMethod,
163                             double[] argsSet, double valueTolerance) {
164         try {
165 
166             // call the nabla method
167             Object[] nablaArgs = convert(nablaMethod.getParameterTypes(), argsSet);
168             DifferentialPair dp = (DifferentialPair) nablaMethod.invoke(null, nablaArgs);
169 
170             // call the reference java method
171             Object[] javaArgs = convert(javaMethod.getParameterTypes(), argsSet);
172             double d = ((Double) javaMethod.invoke(null, javaArgs)).doubleValue();
173 
174             // check the nabla and java classes compute the same function
175             Assert.assertEquals(d, dp.getValue(), valueTolerance);
176 
177         } catch (InvocationTargetException ite) {
178             Assert.fail(ite.getMessage());
179         } catch (IllegalAccessException iae) {
180             Assert.fail(iae.getMessage());
181         }
182     }
183 
184     /** Check the differential of the function.
185      * @param nablaMethod method with differential
186      * @param javaMethod method without differential
187      * @param argsSet arguments set
188      * @param differentialTolerance tolerance on the differential
189      */
190     private void checkDifferential(Method nablaMethod, final Method javaMethod,
191                                  double[] argsSet,
192                                  double differentialTolerance) {
193         try {
194 
195             // compute the differential as implemented by the nabla function
196             final Object[] converted = convert(nablaMethod.getParameterTypes(), argsSet);
197             DifferentialPair nablaDP = (DifferentialPair) nablaMethod.invoke(null, converted);
198 
199             // compute the reference differential by finite differences on the java function
200             // (in fact, this reference differential will be LESS accurate than the nabla one ...)
201             DifferentialPair differencesDP = new EightPointsScheme(1.0e-3).differentiate(new UnivariateDifferentiable() {
202                 public double f(double x) {
203                     try {
204                         Object[] changed = new Object[converted.length];
205                         for (int i = 0; i < converted.length; ++i) {
206                             if (converted[i] instanceof DifferentialPair) {
207                                 DifferentialPair a = (DifferentialPair) converted[i];
208                                 changed[i] = new Double(a.getValue() + x * a.getFirstDerivative());                
209                             } else {
210                                 changed[i] = converted[i];
211                             }
212                         }
213                         return ((Double) javaMethod.invoke(null, changed)).doubleValue();
214                     } catch (IllegalArgumentException e) {
215                         return Double.NaN;
216                     } catch (IllegalAccessException e) {
217                         return Double.NaN;
218                     } catch (InvocationTargetException e) {
219                         return Double.NaN;
220                     }
221                 }
222             }).f(DifferentialPair.newVariable(0.0));
223 
224             // check the nabla and java classes compute the same differential
225             Assert.assertEquals(differencesDP.getFirstDerivative(), nablaDP.getFirstDerivative(),
226                                 differentialTolerance);
227 
228         } catch (InvocationTargetException ite) {
229             Assert.fail(ite.getMessage());
230         } catch (IllegalAccessException iae) {
231             Assert.fail(iae.getMessage());
232         }
233     }
234 
235     /** Converts an arguments array, with types conversion as needed.
236      * @param types types of the arguments
237      * @param values raw values of the arguments
238      * @return an arguments array
239      */
240     private Object[] convert(Class<?>[] types, double[] values) {
241         Object[] arguments = new Object[values.length];
242         for (int i = 0; i < values.length; ++i) {
243             if (types[i].equals(DifferentialPair.class)) {
244                 arguments[i] = DifferentialPair.newVariable(values[i]);
245             } else if (types[i].equals(Double.TYPE)) {
246                 arguments[i] = new Double(values[i]);
247             } else if (types[i].equals(Integer.TYPE)) {
248                 arguments[i] = Integer.valueOf((int) values[i]);
249             } else if (types[i].equals(Long.TYPE)) {
250                 arguments[i] = Long.valueOf((long) values[i]);
251             }
252         }
253         return arguments;
254     }
255 
256 }