1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.nabla;
18
19 import java.lang.reflect.InvocationTargetException;
20 import java.lang.reflect.Method;
21
22 import org.apache.commons.nabla.core.DifferentialPair;
23 import org.apache.commons.nabla.core.UnivariateDifferentiable;
24 import org.apache.commons.nabla.numerical.EightPointsScheme;
25 import org.junit.Assert;
26
27
28 public abstract class AbstractStaticFunctionsTest {
29
30 /** Get the nabla class to test (either NablaMath or NablaStrictMath).
31 * @return nabla class
32 */
33 protected abstract Class<?> getNablaClass();
34
35 /** Get the java class to test (either Math or StrictMath).
36 * @return java class
37 */
38 protected abstract Class<?> getJavaClass();
39
40 protected void defaultMonadicTest(String name) {
41 checkMonadicFunction(getNablaClass(), getJavaClass(), name,
42 0.1, 0.3, 10, 1.0e-20, 2.0e-12);
43 }
44
45 protected void defaultDiadicTest(String name) {
46 checkDiadicFunction(getNablaClass(), getJavaClass(), name,
47 1.0, 4.0, 4, 2.0, 7.0, 6, 1.0e-15, 1.0e-11);
48 }
49
50 /** Check a monadic static function in a range.
51 * @param nablaClass class implementing the function with differentials
52 * @param javaClass class implementing the function without differentials
53 * @param name name of the function
54 * @param lower lower bound of the check interval
55 * @param upper upper bound of the check interval
56 * @param n number of check points
57 * @param valueTolerance tolerance on the value
58 * @param differentialTolerance tolerance on the differentials
59 */
60 protected void checkMonadicFunction(Class<?> nablaClass, Class<?> javaClass,
61 String name,
62 double lower, double upper, int n,
63 double valueTolerance,
64 double differentialTolerance) {
65
66 // build the sequence of arguments sets
67 double[][] argsSetsSequence = new double[n][];
68 for (int i = 0; i < n; ++i) {
69 double arg = (((n - 1) - i) * lower + i * upper) / (n - 1);
70 argsSetsSequence[i] = new double[] { arg };
71 }
72
73 // perform the check
74 checkFunction(nablaClass, javaClass, new Class<?>[] { Double.TYPE },
75 name, argsSetsSequence, valueTolerance, differentialTolerance);
76
77 }
78
79 /** Check a diadic static function in a range.
80 * @param nablaClass class implementing the function with differentials
81 * @param javaClass class implementing the function without differentials
82 * @param name name of the function
83 * @param lower1 lower bound of the check interval for the first argument
84 * @param upper1 upper bound of the check interval for the first argument
85 * @param n1 number of check points for the first argument
86 * @param lower2 lower bound of the check interval for the second argument
87 * @param upper2 upper bound of the check interval for the second argument
88 * @param n2 number of check points for the second argument
89 * @param valueTolerance tolerance on the value
90 * @param differentialTolerance tolerance on the differentials
91 */
92 protected void checkDiadicFunction(Class<?> nablaClass, Class<?> javaClass,
93 String name,
94 double lower1, double upper1, int n1,
95 double lower2, double upper2, int n2,
96 double valueTolerance,
97 double differentialTolerance) {
98
99 // build the sequence of arguments sets
100 double[][] argsSetsSequence = new double[n1 * n2][];
101 int k = 0;
102 for (int i = 0; i < n1; ++i) {
103 double arg1 = (((n1 - 1) - i) * lower1 + i * upper1) / (n1 - 1);
104 for (int j = 0; j < n2; ++j) {
105 double arg2 = (((n2 - 1) - j) * lower2 + j * upper2) / (n2 - 1);
106 argsSetsSequence[k++] = new double[] { arg1, arg2 };
107 }
108 }
109
110 // perform the check
111 checkFunction(nablaClass, javaClass, new Class<?>[] { Double.TYPE, Double.TYPE },
112 name, argsSetsSequence, valueTolerance, differentialTolerance);
113
114 }
115
116 /** Check a static function for a sequence of arguments sets.
117 * @param nablaClass class implementing the function with differentials
118 * @param javaClass class implementing the function without differentials
119 * @param javaTypes types of the parameters of the java reference method
120 * @param name name of the function
121 * @param argsSetsSequence sequence of arguments sets
122 * @param valueTolerance tolerance on the value
123 * @param differentialTolerance tolerance on the differential
124 */
125 protected void checkFunction(Class<?> nablaClass,
126 Class<?> javaClass, Class<?>[] javaTypes,
127 String name, double[][] argsSetsSequence,
128 double valueTolerance,
129 double differentialTolerance) {
130 try {
131
132 // get the reference java method
133 Method javaMethod = javaClass.getMethod(name, javaTypes);
134
135 // get the nabla methods we can test
136 Method[] nablaMethods = nablaClass.getMethods();
137 for (int i = 0; i < nablaMethods.length; ++i) {
138 if (nablaMethods[i].getName().equals(name) &&
139 (nablaMethods[i].getParameterTypes().length == argsSetsSequence[0].length)) {
140
141 // test this method
142 for (int j = 0; j < argsSetsSequence.length; ++j) {
143 checkValue(nablaMethods[i], javaMethod, argsSetsSequence[j], valueTolerance);
144 checkDifferential(nablaMethods[i], javaMethod, argsSetsSequence[j], differentialTolerance);
145 }
146
147 }
148 }
149
150 } catch (NoSuchMethodException nsme) {
151 Assert.fail(nsme.getMessage());
152 }
153
154 }
155
156 /** Check the value of the function.
157 * @param nablaMethod method with differential
158 * @param javaMethod method without differential
159 * @param argsSet arguments set
160 * @param valueTolerance tolerance on the value
161 */
162 private void checkValue(Method nablaMethod, Method javaMethod,
163 double[] argsSet, double valueTolerance) {
164 try {
165
166 // call the nabla method
167 Object[] nablaArgs = convert(nablaMethod.getParameterTypes(), argsSet);
168 DifferentialPair dp = (DifferentialPair) nablaMethod.invoke(null, nablaArgs);
169
170 // call the reference java method
171 Object[] javaArgs = convert(javaMethod.getParameterTypes(), argsSet);
172 double d = ((Double) javaMethod.invoke(null, javaArgs)).doubleValue();
173
174 // check the nabla and java classes compute the same function
175 Assert.assertEquals(d, dp.getValue(), valueTolerance);
176
177 } catch (InvocationTargetException ite) {
178 Assert.fail(ite.getMessage());
179 } catch (IllegalAccessException iae) {
180 Assert.fail(iae.getMessage());
181 }
182 }
183
184 /** Check the differential of the function.
185 * @param nablaMethod method with differential
186 * @param javaMethod method without differential
187 * @param argsSet arguments set
188 * @param differentialTolerance tolerance on the differential
189 */
190 private void checkDifferential(Method nablaMethod, final Method javaMethod,
191 double[] argsSet,
192 double differentialTolerance) {
193 try {
194
195 // compute the differential as implemented by the nabla function
196 final Object[] converted = convert(nablaMethod.getParameterTypes(), argsSet);
197 DifferentialPair nablaDP = (DifferentialPair) nablaMethod.invoke(null, converted);
198
199 // compute the reference differential by finite differences on the java function
200 // (in fact, this reference differential will be LESS accurate than the nabla one ...)
201 DifferentialPair differencesDP = new EightPointsScheme(1.0e-3).differentiate(new UnivariateDifferentiable() {
202 public double f(double x) {
203 try {
204 Object[] changed = new Object[converted.length];
205 for (int i = 0; i < converted.length; ++i) {
206 if (converted[i] instanceof DifferentialPair) {
207 DifferentialPair a = (DifferentialPair) converted[i];
208 changed[i] = new Double(a.getValue() + x * a.getFirstDerivative());
209 } else {
210 changed[i] = converted[i];
211 }
212 }
213 return ((Double) javaMethod.invoke(null, changed)).doubleValue();
214 } catch (IllegalArgumentException e) {
215 return Double.NaN;
216 } catch (IllegalAccessException e) {
217 return Double.NaN;
218 } catch (InvocationTargetException e) {
219 return Double.NaN;
220 }
221 }
222 }).f(DifferentialPair.newVariable(0.0));
223
224 // check the nabla and java classes compute the same differential
225 Assert.assertEquals(differencesDP.getFirstDerivative(), nablaDP.getFirstDerivative(),
226 differentialTolerance);
227
228 } catch (InvocationTargetException ite) {
229 Assert.fail(ite.getMessage());
230 } catch (IllegalAccessException iae) {
231 Assert.fail(iae.getMessage());
232 }
233 }
234
235 /** Converts an arguments array, with types conversion as needed.
236 * @param types types of the arguments
237 * @param values raw values of the arguments
238 * @return an arguments array
239 */
240 private Object[] convert(Class<?>[] types, double[] values) {
241 Object[] arguments = new Object[values.length];
242 for (int i = 0; i < values.length; ++i) {
243 if (types[i].equals(DifferentialPair.class)) {
244 arguments[i] = DifferentialPair.newVariable(values[i]);
245 } else if (types[i].equals(Double.TYPE)) {
246 arguments[i] = new Double(values[i]);
247 } else if (types[i].equals(Integer.TYPE)) {
248 arguments[i] = Integer.valueOf((int) values[i]);
249 } else if (types[i].equals(Long.TYPE)) {
250 arguments[i] = Long.valueOf((long) values[i]);
251 }
252 }
253 return arguments;
254 }
255
256 }