1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla.forward;
18
19 import java.io.IOException;
20 import java.io.OutputStream;
21 import java.lang.reflect.Constructor;
22 import java.lang.reflect.InvocationTargetException;
23 import java.util.HashMap;
24 import java.util.HashSet;
25 import java.util.Set;
26
27 import org.apache.commons.math3.analysis.UnivariateFunction;
28 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
29 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator;
31 import org.apache.commons.math3.util.FastMath;
32 import org.apache.commons.nabla.DifferentiationException;
33 import org.apache.commons.nabla.NablaMessages;
34 import org.apache.commons.nabla.forward.analysis.ClassDifferentiator;
35 import org.objectweb.asm.ClassWriter;
36 import org.objectweb.asm.Type;
37 import org.objectweb.asm.tree.ClassNode;
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57 public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
58
59
60 private final HashMap<Class<? extends UnivariateFunction>,
61 Class<? extends UnivariateDifferentiableFunction>> map;
62
63
64 private final HashMap<String, byte[]> byteCodeMap;
65
66
67 private final Set<String> mathClasses;
68
69
70
71
72 public ForwardModeDifferentiator() {
73 map = new HashMap<Class<? extends UnivariateFunction>,
74 Class<? extends UnivariateDifferentiableFunction>>();
75 byteCodeMap = new HashMap<String, byte[]>();
76 mathClasses = new HashSet<String>();
77 addMathImplementation(Math.class);
78 addMathImplementation(StrictMath.class);
79 addMathImplementation(FastMath.class);
80 }
81
82
83
84
85
86
87
88
89
90
91 public void addMathImplementation(final Class<?> mathClass) {
92 mathClasses.add(mathClass.getName().replace('.', '/'));
93 }
94
95
96
97
98 public void dumpCache(final OutputStream out) {
99
100 throw new RuntimeException("not implemented yet");
101 }
102
103
104 public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
105
106
107 final Class<? extends UnivariateDifferentiableFunction> derivativeClass =
108 getDerivativeClass(d.getClass());
109
110 try {
111
112
113 final Constructor<? extends UnivariateDifferentiableFunction> constructor =
114 derivativeClass.getConstructor(d.getClass());
115 return constructor.newInstance(d);
116
117 } catch (InstantiationException ie) {
118 throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS,
119 derivativeClass.getName(), ie.getMessage());
120 } catch (IllegalAccessException iae) {
121 throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR,
122 derivativeClass.getName(), iae.getMessage());
123 } catch (NoSuchMethodException nsme) {
124 throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS,
125 derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
126 } catch (InvocationTargetException ite) {
127 throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE,
128 derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
129 } catch (VerifyError ve) {
130 throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE,
131 derivativeClass.getName(), d.getClass().getName(), ve.getMessage());
132 }
133
134 }
135
136
137
138
139
140
141
142
143 private Class<? extends UnivariateDifferentiableFunction>
144 getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
145 throws DifferentiationException {
146
147
148 Class<? extends UnivariateDifferentiableFunction> derivativeClass =
149 map.get(differentiableClass);
150
151
152 if (derivativeClass == null) {
153
154
155 derivativeClass = createDerivativeClass(differentiableClass);
156
157
158 map.put(differentiableClass, derivativeClass);
159
160 }
161
162
163 return derivativeClass;
164
165 }
166
167
168
169
170
171
172 private Class<? extends UnivariateDifferentiableFunction>
173 createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
174 throws DifferentiationException {
175 try {
176
177
178 final ClassDifferentiator differentiator =
179 new ClassDifferentiator(differentiableClass, mathClasses);
180 final Type dsType = Type.getType(DerivativeStructure.class);
181 differentiator.differentiateMethod("value",
182 Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
183 Type.getMethodDescriptor(dsType, dsType));
184
185
186 final ClassNode derived = differentiator.getDerivedClass();
187 final ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
188 final String name = derived.name.replace('/', '.');
189 derived.accept(writer);
190 final byte[] bytecode = writer.toByteArray();
191
192 final Class<? extends UnivariateDifferentiableFunction> dClass =
193 new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
194 byteCodeMap.put(name, bytecode);
195 return dClass;
196
197 } catch (IOException ioe) {
198 throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
199 differentiableClass.getName(), ioe.getMessage());
200 }
201 }
202
203
204 private static class DerivativeLoader extends ClassLoader {
205
206
207
208
209 public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) {
210 super(differentiableClass.getClassLoader());
211 }
212
213
214
215
216
217
218 @SuppressWarnings("unchecked")
219 public Class<? extends UnivariateDifferentiableFunction>
220 defineClass(final String name, final byte[] bytecode) {
221 return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length);
222 }
223 }
224
225 }