001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.nabla.forward.analysis;
018
019 import java.io.IOException;
020 import java.io.InputStream;
021 import java.lang.reflect.Field;
022 import java.util.Set;
023
024 import org.apache.commons.math3.analysis.UnivariateFunction;
025 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026 import org.apache.commons.nabla.DifferentiationException;
027 import org.apache.commons.nabla.NablaMessages;
028 import org.objectweb.asm.ClassReader;
029 import org.objectweb.asm.Label;
030 import org.objectweb.asm.Opcodes;
031 import org.objectweb.asm.Type;
032 import org.objectweb.asm.tree.ClassNode;
033 import org.objectweb.asm.tree.FieldNode;
034 import org.objectweb.asm.tree.MethodNode;
035
036 /**
037 * Differentiator for classes using forward mode.
038 * <p>
039 * This differentiator transforms classes implementing the
040 * {@link UnivariateFunction UnivariateFunction} interface and convert
041 * them to classes implementing the {@link UnivariateDifferentiableFunction
042 * UnivariateDifferentiableFunction} interface.
043 * </p>
044 * <p>
045 * The differentiator creates a new class in the same package as the primitive class and
046 * which only preserve a private reference to the primitive instance. They access the
047 * current value of all necessary primitive instance fields thanks to reflection and
048 * bypassing access restrictions.
049 * </p>
050 * <p>
051 * The original class bytecode is not changed at all.
052 * </p>
053 * @version $Id$
054 */
055 public class ClassDifferentiator {
056
057 /** Name for the primitive instance field. */
058 private static final String PRIMITIVE_FIELD = "primitive";
059
060 /** Name fo the constructor methods. */
061 private static final String INIT = "<init>";
062
063 /** Math implementation classes. */
064 private final Set<String> mathClasses;
065
066 /** Class to differentiate. */
067 private final Class<? extends UnivariateFunction> primitiveClass;
068
069 /** Node of the class to differentiate. */
070 private final ClassNode primitiveNode;
071
072 /** Class to differentiate. */
073 private final ClassNode classNode;
074
075 /**
076 * Simple constructor.
077 * @param primitiveClass primitive class
078 * @param mathClasses math implementation classes
079 * @exception DifferentiationException if class cannot be differentiated
080 * @throws IOException if class cannot be read
081 */
082 public ClassDifferentiator(final Class<? extends UnivariateFunction> primitiveClass,
083 final Set<String> mathClasses)
084 throws DifferentiationException, IOException {
085
086 // get the original class
087 this.primitiveClass = primitiveClass;
088 final String classResourceName = "/" + primitiveClass.getName().replace('.', '/') + ".class";
089 final InputStream stream = primitiveClass.getResourceAsStream(classResourceName);
090 final ClassReader reader = new ClassReader(stream);
091 primitiveNode = new ClassNode(Opcodes.ASM4);
092 reader.accept(primitiveNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
093 this.mathClasses = mathClasses;
094 classNode = new ClassNode(Opcodes.ASM4);
095
096 // check the UnivariateFunction interface is implemented
097 final Class<UnivariateFunction> uFuncClass = UnivariateFunction.class;
098 boolean isDifferentiable = false;
099 for (String interf : primitiveNode.interfaces) {
100 final String interfName = interf.replace('/', '.');
101 Class<?> interfClass = null;
102 try {
103 interfClass = Class.forName(interfName);
104 } catch (ClassNotFoundException cnfe) {
105 // this should never occur since class has already been loaded
106 // and an instance already exists ...
107 throw new DifferentiationException(NablaMessages.INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING,
108 interfName, primitiveNode.name);
109 }
110 if (interfClass != null) {
111 isDifferentiable = isDifferentiable || uFuncClass.isAssignableFrom(interfClass);
112 }
113 }
114
115 if (!isDifferentiable) {
116 throw new DifferentiationException(NablaMessages.CLASS_DOES_NOT_IMPLEMENT_INTERFACE,
117 primitiveNode.name, uFuncClass.getName());
118 }
119
120 // change the class properties for the derived class
121 classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
122 primitiveNode.name + "_NablaForwardModeUnivariateDerivative",
123 null, Type.getType(Object.class).getInternalName(),
124 new String[] {
125 UnivariateDifferentiableFunction.class.getName().replace('.', '/')
126 });
127
128 // add boilerplate code
129 addPrimitiveField();
130 addConstructor();
131 addGetPrimitiveFieldMethod();
132
133 }
134
135 /**
136 * Differentiate a method.
137 * @param name of the method
138 * @param primitiveDesc descriptor of the method in the primitive class
139 * @param derivativeDesc descriptor of the method in the derivative class
140 * @exception DifferentiationException if method cannot be differentiated
141 */
142 public void differentiateMethod(final String name, final String primitiveDesc,
143 final String derivativeDesc)
144 throws DifferentiationException {
145
146 for (final MethodNode method : primitiveNode.methods) {
147 if (method.name.equals(name) && method.desc.equals(primitiveDesc)) {
148
149 final MethodDifferentiator differentiator =
150 new MethodDifferentiator(mathClasses, classNode.name);
151 differentiator.differentiate(primitiveNode.name, method);
152 classNode.methods.add(method);
153
154 }
155 }
156 }
157
158 /**
159 * Get the derived class.
160 * @return derived class
161 */
162 public ClassNode getDerivedClass() {
163 return classNode;
164 }
165
166 /** Add the primitive field.
167 */
168 private void addPrimitiveField() {
169 final FieldNode primitiveField =
170 new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
171 PRIMITIVE_FIELD, Type.getDescriptor(primitiveClass), null, null);
172 classNode.fields.add(primitiveField);
173 }
174
175 /** Add the class constructor.
176 */
177 private void addConstructor() {
178 final MethodNode constructor =
179 new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
180 Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
181 null, null);
182 constructor.visitVarInsn(Opcodes.ALOAD, 0);
183 constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getType(Object.class).getInternalName(),
184 INIT, "()V");
185 constructor.visitVarInsn(Opcodes.ALOAD, 0);
186 constructor.visitVarInsn(Opcodes.ALOAD, 1);
187 constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD,
188 Type.getDescriptor(primitiveClass));
189 constructor.visitInsn(Opcodes.RETURN);
190 constructor.visitMaxs(0, 0);
191 classNode.methods.add(constructor);
192 }
193
194 /** Add the getPrimitiveField method.
195 */
196 private void addGetPrimitiveFieldMethod() {
197 final MethodNode method =
198 new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveField",
199 Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)),
200 null, null);
201 final Label start = new Label();
202 final Label end = new Label();
203 method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class));
204 method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class));
205 method.visitLabel(start);
206 method.visitLdcInsn(Type.getType(primitiveClass));
207 method.visitVarInsn(Opcodes.ALOAD, 1);
208 method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
209 "getDeclaredField",
210 Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)));
211 method.visitVarInsn(Opcodes.ASTORE, 2);
212 method.visitVarInsn(Opcodes.ALOAD, 2);
213 method.visitInsn(Opcodes.ICONST_1);
214 method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
215 "setAccessible",
216 Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE));
217 method.visitVarInsn(Opcodes.ALOAD, 2);
218 method.visitVarInsn(Opcodes.ALOAD, 0);
219 method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD,
220 Type.getDescriptor(primitiveClass));
221 method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
222 "get",
223 Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)));
224 method.visitInsn(Opcodes.ARETURN);
225 method.visitLabel(end);
226 method.visitVarInsn(Opcodes.ASTORE, 2);
227 method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class));
228 method.visitInsn(Opcodes.DUP);
229 method.visitVarInsn(Opcodes.ALOAD, 2);
230 method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class),
231 INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class)));
232 method.visitInsn(Opcodes.ATHROW);
233 classNode.methods.add(method);
234 }
235
236 }