View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla.forward.analysis;
18  
19  import java.io.IOException;
20  import java.io.InputStream;
21  import java.lang.reflect.Field;
22  import java.util.Set;
23  
24  import org.apache.commons.math3.analysis.UnivariateFunction;
25  import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
26  import org.apache.commons.nabla.DifferentiationException;
27  import org.apache.commons.nabla.NablaMessages;
28  import org.objectweb.asm.ClassReader;
29  import org.objectweb.asm.Label;
30  import org.objectweb.asm.Opcodes;
31  import org.objectweb.asm.Type;
32  import org.objectweb.asm.tree.ClassNode;
33  import org.objectweb.asm.tree.FieldNode;
34  import org.objectweb.asm.tree.MethodNode;
35  
36  /**
37   * Differentiator for classes using forward mode.
38   * <p>
39   * This differentiator transforms classes implementing the
40   * {@link UnivariateFunction UnivariateFunction} interface and convert
41   * them to classes implementing the {@link UnivariateDifferentiableFunction
42   * UnivariateDifferentiableFunction} interface.
43   * </p>
44   * <p>
45   * The differentiator creates a new class in the same package as the primitive class and
46   * which only preserve a private reference to the primitive instance. They access the
47   * current value of all necessary primitive instance fields thanks to reflection and
48   * bypassing access restrictions.
49   * </p>
50   * <p>
51   * The original class bytecode is not changed at all.
52   * </p>
53   * @version $Id$
54   */
55  public class ClassDifferentiator {
56  
57      /** Name for the primitive instance field. */
58      private static final String PRIMITIVE_FIELD = "primitive";
59  
60      /** Name fo the constructor methods. */
61      private static final String INIT = "<init>";
62  
63      /** Math implementation classes. */
64      private final Set<String> mathClasses;
65  
66      /** Class to differentiate. */
67      private final Class<? extends UnivariateFunction> primitiveClass;
68  
69      /** Node of the class to differentiate. */
70      private final ClassNode primitiveNode;
71  
72      /** Class to differentiate. */
73      private final ClassNode classNode;
74  
75      /**
76       * Simple constructor.
77       * @param primitiveClass primitive class
78       * @param mathClasses math implementation classes
79       * @exception DifferentiationException if class cannot be differentiated
80       * @throws IOException if class cannot be read
81       */
82      public ClassDifferentiator(final Class<? extends UnivariateFunction> primitiveClass,
83                                 final Set<String> mathClasses)
84          throws DifferentiationException, IOException {
85  
86          // get the original class
87          this.primitiveClass = primitiveClass;
88          final String classResourceName = "/" + primitiveClass.getName().replace('.', '/') + ".class";
89          final InputStream stream = primitiveClass.getResourceAsStream(classResourceName);
90          final ClassReader reader = new ClassReader(stream);
91          primitiveNode = new ClassNode(Opcodes.ASM4);
92          reader.accept(primitiveNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
93          this.mathClasses = mathClasses;
94          classNode = new ClassNode(Opcodes.ASM4);
95  
96          // check the UnivariateFunction interface is implemented
97          final Class<UnivariateFunction> uFuncClass = UnivariateFunction.class;
98          boolean isDifferentiable = false;
99          for (String interf : primitiveNode.interfaces) {
100             final String interfName = interf.replace('/', '.');
101             Class<?> interfClass = null;
102             try {
103                 interfClass = Class.forName(interfName);
104             } catch (ClassNotFoundException cnfe) {
105                 // this should never occur since class has already been loaded
106                 // and an instance already exists ...
107                 throw new DifferentiationException(NablaMessages.INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING,
108                                                    interfName, primitiveNode.name);
109             }
110             if (interfClass != null) {
111                 isDifferentiable = isDifferentiable || uFuncClass.isAssignableFrom(interfClass);
112             }
113         }
114 
115         if (!isDifferentiable) {
116             throw new DifferentiationException(NablaMessages.CLASS_DOES_NOT_IMPLEMENT_INTERFACE,
117                                                primitiveNode.name, uFuncClass.getName());
118         }
119 
120         // change the class properties for the derived class
121         classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
122                         primitiveNode.name + "_NablaForwardModeUnivariateDerivative",
123                         null, Type.getType(Object.class).getInternalName(),
124                         new String[] {
125                             UnivariateDifferentiableFunction.class.getName().replace('.', '/')
126                         });
127 
128         // add boilerplate code
129         addPrimitiveField();
130         addConstructor();
131         addGetPrimitiveFieldMethod();
132 
133     }
134 
135     /**
136      * Differentiate a method.
137      * @param name of the method
138      * @param primitiveDesc descriptor of the method in the primitive class
139      * @param derivativeDesc descriptor of the method in the derivative class
140      * @exception DifferentiationException if method cannot be differentiated
141      */
142     public void differentiateMethod(final String name, final String primitiveDesc,
143                                     final String derivativeDesc)
144         throws DifferentiationException {
145 
146         for (final MethodNode method : primitiveNode.methods) {
147             if (method.name.equals(name) && method.desc.equals(primitiveDesc)) {
148 
149                 final MethodDifferentiator differentiator =
150                         new MethodDifferentiator(mathClasses, classNode.name);
151                 differentiator.differentiate(primitiveNode.name, method);
152                 classNode.methods.add(method);
153 
154             }
155         }
156     }
157 
158     /**
159      * Get the derived class.
160      * @return derived class
161      */
162     public ClassNode getDerivedClass() {
163         return classNode;
164     }
165 
166     /** Add the primitive field.
167      */
168     private void addPrimitiveField() {
169         final FieldNode primitiveField =
170             new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
171                           PRIMITIVE_FIELD, Type.getDescriptor(primitiveClass), null, null);
172         classNode.fields.add(primitiveField);
173     }
174 
175     /** Add the class constructor.
176      */
177     private void addConstructor() {
178         final MethodNode constructor =
179             new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
180                            Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
181                            null, null);
182         constructor.visitVarInsn(Opcodes.ALOAD, 0);
183         constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getType(Object.class).getInternalName(),
184                                     INIT, "()V");
185         constructor.visitVarInsn(Opcodes.ALOAD, 0);
186         constructor.visitVarInsn(Opcodes.ALOAD, 1);
187         constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD,
188                                    Type.getDescriptor(primitiveClass));
189         constructor.visitInsn(Opcodes.RETURN);
190         constructor.visitMaxs(0, 0);
191         classNode.methods.add(constructor);
192     }
193 
194     /** Add the getPrimitiveField method.
195      */
196     private void addGetPrimitiveFieldMethod() {
197         final MethodNode method =
198             new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveField",
199                            Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)),
200                            null, null);
201         final Label start     = new Label();
202         final Label end       = new Label();
203         method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class));
204         method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class));
205         method.visitLabel(start);
206         method.visitLdcInsn(Type.getType(primitiveClass));
207         method.visitVarInsn(Opcodes.ALOAD, 1);
208         method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
209                                "getDeclaredField",
210                                Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)));
211         method.visitVarInsn(Opcodes.ASTORE, 2);
212         method.visitVarInsn(Opcodes.ALOAD, 2);
213         method.visitInsn(Opcodes.ICONST_1);
214         method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
215                                "setAccessible",
216                                Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE));
217         method.visitVarInsn(Opcodes.ALOAD, 2);
218         method.visitVarInsn(Opcodes.ALOAD, 0);
219         method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD,
220                               Type.getDescriptor(primitiveClass));
221         method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
222                                "get",
223                                Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)));
224         method.visitInsn(Opcodes.ARETURN);
225         method.visitLabel(end);
226         method.visitVarInsn(Opcodes.ASTORE, 2);
227         method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class));
228         method.visitInsn(Opcodes.DUP);
229         method.visitVarInsn(Opcodes.ALOAD, 2);
230         method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class),
231                                INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class)));
232         method.visitInsn(Opcodes.ATHROW);
233         classNode.methods.add(method);
234     }
235 
236 }