001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.forward.analysis;
018    
019    import java.io.IOException;
020    import java.io.InputStream;
021    import java.lang.reflect.Field;
022    import java.util.Set;
023    
024    import org.apache.commons.math3.analysis.UnivariateFunction;
025    import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026    import org.apache.commons.nabla.DifferentiationException;
027    import org.apache.commons.nabla.NablaMessages;
028    import org.objectweb.asm.ClassReader;
029    import org.objectweb.asm.Label;
030    import org.objectweb.asm.Opcodes;
031    import org.objectweb.asm.Type;
032    import org.objectweb.asm.tree.ClassNode;
033    import org.objectweb.asm.tree.FieldNode;
034    import org.objectweb.asm.tree.MethodNode;
035    
036    /**
037     * Differentiator for classes using forward mode.
038     * <p>
039     * This differentiator transforms classes implementing the
040     * {@link UnivariateFunction UnivariateFunction} interface and convert
041     * them to classes implementing the {@link UnivariateDifferentiableFunction
042     * UnivariateDifferentiableFunction} interface.
043     * </p>
044     * <p>
045     * The differentiator creates a new class in the same package as the primitive class and
046     * which only preserve a private reference to the primitive instance. They access the
047     * current value of all necessary primitive instance fields thanks to reflection and
048     * bypassing access restrictions.
049     * </p>
050     * <p>
051     * The original class bytecode is not changed at all.
052     * </p>
053     * @version $Id$
054     */
055    public class ClassDifferentiator {
056    
057        /** Name for the primitive instance field. */
058        private static final String PRIMITIVE_FIELD = "primitive";
059    
060        /** Name fo the constructor methods. */
061        private static final String INIT = "<init>";
062    
063        /** Math implementation classes. */
064        private final Set<String> mathClasses;
065    
066        /** Class to differentiate. */
067        private final Class<? extends UnivariateFunction> primitiveClass;
068    
069        /** Node of the class to differentiate. */
070        private final ClassNode primitiveNode;
071    
072        /** Class to differentiate. */
073        private final ClassNode classNode;
074    
075        /**
076         * Simple constructor.
077         * @param primitiveClass primitive class
078         * @param mathClasses math implementation classes
079         * @exception DifferentiationException if class cannot be differentiated
080         * @throws IOException if class cannot be read
081         */
082        public ClassDifferentiator(final Class<? extends UnivariateFunction> primitiveClass,
083                                   final Set<String> mathClasses)
084            throws DifferentiationException, IOException {
085    
086            // get the original class
087            this.primitiveClass = primitiveClass;
088            final String classResourceName = "/" + primitiveClass.getName().replace('.', '/') + ".class";
089            final InputStream stream = primitiveClass.getResourceAsStream(classResourceName);
090            final ClassReader reader = new ClassReader(stream);
091            primitiveNode = new ClassNode(Opcodes.ASM4);
092            reader.accept(primitiveNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
093            this.mathClasses = mathClasses;
094            classNode = new ClassNode(Opcodes.ASM4);
095    
096            // check the UnivariateFunction interface is implemented
097            final Class<UnivariateFunction> uFuncClass = UnivariateFunction.class;
098            boolean isDifferentiable = false;
099            for (String interf : primitiveNode.interfaces) {
100                final String interfName = interf.replace('/', '.');
101                Class<?> interfClass = null;
102                try {
103                    interfClass = Class.forName(interfName);
104                } catch (ClassNotFoundException cnfe) {
105                    // this should never occur since class has already been loaded
106                    // and an instance already exists ...
107                    throw new DifferentiationException(NablaMessages.INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING,
108                                                       interfName, primitiveNode.name);
109                }
110                if (interfClass != null) {
111                    isDifferentiable = isDifferentiable || uFuncClass.isAssignableFrom(interfClass);
112                }
113            }
114    
115            if (!isDifferentiable) {
116                throw new DifferentiationException(NablaMessages.CLASS_DOES_NOT_IMPLEMENT_INTERFACE,
117                                                   primitiveNode.name, uFuncClass.getName());
118            }
119    
120            // change the class properties for the derived class
121            classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
122                            primitiveNode.name + "_NablaForwardModeUnivariateDerivative",
123                            null, Type.getType(Object.class).getInternalName(),
124                            new String[] {
125                                UnivariateDifferentiableFunction.class.getName().replace('.', '/')
126                            });
127    
128            // add boilerplate code
129            addPrimitiveField();
130            addConstructor();
131            addGetPrimitiveFieldMethod();
132    
133        }
134    
135        /**
136         * Differentiate a method.
137         * @param name of the method
138         * @param primitiveDesc descriptor of the method in the primitive class
139         * @param derivativeDesc descriptor of the method in the derivative class
140         * @exception DifferentiationException if method cannot be differentiated
141         */
142        public void differentiateMethod(final String name, final String primitiveDesc,
143                                        final String derivativeDesc)
144            throws DifferentiationException {
145    
146            for (final MethodNode method : primitiveNode.methods) {
147                if (method.name.equals(name) && method.desc.equals(primitiveDesc)) {
148    
149                    final MethodDifferentiator differentiator =
150                            new MethodDifferentiator(mathClasses, classNode.name);
151                    differentiator.differentiate(primitiveNode.name, method);
152                    classNode.methods.add(method);
153    
154                }
155            }
156        }
157    
158        /**
159         * Get the derived class.
160         * @return derived class
161         */
162        public ClassNode getDerivedClass() {
163            return classNode;
164        }
165    
166        /** Add the primitive field.
167         */
168        private void addPrimitiveField() {
169            final FieldNode primitiveField =
170                new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
171                              PRIMITIVE_FIELD, Type.getDescriptor(primitiveClass), null, null);
172            classNode.fields.add(primitiveField);
173        }
174    
175        /** Add the class constructor.
176         */
177        private void addConstructor() {
178            final MethodNode constructor =
179                new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
180                               Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
181                               null, null);
182            constructor.visitVarInsn(Opcodes.ALOAD, 0);
183            constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getType(Object.class).getInternalName(),
184                                        INIT, "()V");
185            constructor.visitVarInsn(Opcodes.ALOAD, 0);
186            constructor.visitVarInsn(Opcodes.ALOAD, 1);
187            constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD,
188                                       Type.getDescriptor(primitiveClass));
189            constructor.visitInsn(Opcodes.RETURN);
190            constructor.visitMaxs(0, 0);
191            classNode.methods.add(constructor);
192        }
193    
194        /** Add the getPrimitiveField method.
195         */
196        private void addGetPrimitiveFieldMethod() {
197            final MethodNode method =
198                new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveField",
199                               Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)),
200                               null, null);
201            final Label start     = new Label();
202            final Label end       = new Label();
203            method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class));
204            method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class));
205            method.visitLabel(start);
206            method.visitLdcInsn(Type.getType(primitiveClass));
207            method.visitVarInsn(Opcodes.ALOAD, 1);
208            method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
209                                   "getDeclaredField",
210                                   Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)));
211            method.visitVarInsn(Opcodes.ASTORE, 2);
212            method.visitVarInsn(Opcodes.ALOAD, 2);
213            method.visitInsn(Opcodes.ICONST_1);
214            method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
215                                   "setAccessible",
216                                   Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE));
217            method.visitVarInsn(Opcodes.ALOAD, 2);
218            method.visitVarInsn(Opcodes.ALOAD, 0);
219            method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD,
220                                  Type.getDescriptor(primitiveClass));
221            method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
222                                   "get",
223                                   Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)));
224            method.visitInsn(Opcodes.ARETURN);
225            method.visitLabel(end);
226            method.visitVarInsn(Opcodes.ASTORE, 2);
227            method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class));
228            method.visitInsn(Opcodes.DUP);
229            method.visitVarInsn(Opcodes.ALOAD, 2);
230            method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class),
231                                   INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class)));
232            method.visitInsn(Opcodes.ATHROW);
233            classNode.methods.add(method);
234        }
235    
236    }