001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.nabla.forward.analysis; 018 019 import java.io.IOException; 020 import java.io.InputStream; 021 import java.lang.reflect.Field; 022 import java.util.Set; 023 024 import org.apache.commons.math3.analysis.UnivariateFunction; 025 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 026 import org.apache.commons.nabla.DifferentiationException; 027 import org.apache.commons.nabla.NablaMessages; 028 import org.objectweb.asm.ClassReader; 029 import org.objectweb.asm.Label; 030 import org.objectweb.asm.Opcodes; 031 import org.objectweb.asm.Type; 032 import org.objectweb.asm.tree.ClassNode; 033 import org.objectweb.asm.tree.FieldNode; 034 import org.objectweb.asm.tree.MethodNode; 035 036 /** 037 * Differentiator for classes using forward mode. 038 * <p> 039 * This differentiator transforms classes implementing the 040 * {@link UnivariateFunction UnivariateFunction} interface and convert 041 * them to classes implementing the {@link UnivariateDifferentiableFunction 042 * UnivariateDifferentiableFunction} interface. 043 * </p> 044 * <p> 045 * The differentiator creates a new class in the same package as the primitive class and 046 * which only preserve a private reference to the primitive instance. They access the 047 * current value of all necessary primitive instance fields thanks to reflection and 048 * bypassing access restrictions. 049 * </p> 050 * <p> 051 * The original class bytecode is not changed at all. 052 * </p> 053 * @version $Id$ 054 */ 055 public class ClassDifferentiator { 056 057 /** Name for the primitive instance field. */ 058 private static final String PRIMITIVE_FIELD = "primitive"; 059 060 /** Name fo the constructor methods. */ 061 private static final String INIT = "<init>"; 062 063 /** Math implementation classes. */ 064 private final Set<String> mathClasses; 065 066 /** Class to differentiate. */ 067 private final Class<? extends UnivariateFunction> primitiveClass; 068 069 /** Node of the class to differentiate. */ 070 private final ClassNode primitiveNode; 071 072 /** Class to differentiate. */ 073 private final ClassNode classNode; 074 075 /** 076 * Simple constructor. 077 * @param primitiveClass primitive class 078 * @param mathClasses math implementation classes 079 * @exception DifferentiationException if class cannot be differentiated 080 * @throws IOException if class cannot be read 081 */ 082 public ClassDifferentiator(final Class<? extends UnivariateFunction> primitiveClass, 083 final Set<String> mathClasses) 084 throws DifferentiationException, IOException { 085 086 // get the original class 087 this.primitiveClass = primitiveClass; 088 final String classResourceName = "/" + primitiveClass.getName().replace('.', '/') + ".class"; 089 final InputStream stream = primitiveClass.getResourceAsStream(classResourceName); 090 final ClassReader reader = new ClassReader(stream); 091 primitiveNode = new ClassNode(Opcodes.ASM4); 092 reader.accept(primitiveNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); 093 this.mathClasses = mathClasses; 094 classNode = new ClassNode(Opcodes.ASM4); 095 096 // check the UnivariateFunction interface is implemented 097 final Class<UnivariateFunction> uFuncClass = UnivariateFunction.class; 098 boolean isDifferentiable = false; 099 for (String interf : primitiveNode.interfaces) { 100 final String interfName = interf.replace('/', '.'); 101 Class<?> interfClass = null; 102 try { 103 interfClass = Class.forName(interfName); 104 } catch (ClassNotFoundException cnfe) { 105 // this should never occur since class has already been loaded 106 // and an instance already exists ... 107 throw new DifferentiationException(NablaMessages.INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING, 108 interfName, primitiveNode.name); 109 } 110 if (interfClass != null) { 111 isDifferentiable = isDifferentiable || uFuncClass.isAssignableFrom(interfClass); 112 } 113 } 114 115 if (!isDifferentiable) { 116 throw new DifferentiationException(NablaMessages.CLASS_DOES_NOT_IMPLEMENT_INTERFACE, 117 primitiveNode.name, uFuncClass.getName()); 118 } 119 120 // change the class properties for the derived class 121 classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, 122 primitiveNode.name + "_NablaForwardModeUnivariateDerivative", 123 null, Type.getType(Object.class).getInternalName(), 124 new String[] { 125 UnivariateDifferentiableFunction.class.getName().replace('.', '/') 126 }); 127 128 // add boilerplate code 129 addPrimitiveField(); 130 addConstructor(); 131 addGetPrimitiveFieldMethod(); 132 133 } 134 135 /** 136 * Differentiate a method. 137 * @param name of the method 138 * @param primitiveDesc descriptor of the method in the primitive class 139 * @param derivativeDesc descriptor of the method in the derivative class 140 * @exception DifferentiationException if method cannot be differentiated 141 */ 142 public void differentiateMethod(final String name, final String primitiveDesc, 143 final String derivativeDesc) 144 throws DifferentiationException { 145 146 for (final MethodNode method : primitiveNode.methods) { 147 if (method.name.equals(name) && method.desc.equals(primitiveDesc)) { 148 149 final MethodDifferentiator differentiator = 150 new MethodDifferentiator(mathClasses, classNode.name); 151 differentiator.differentiate(primitiveNode.name, method); 152 classNode.methods.add(method); 153 154 } 155 } 156 } 157 158 /** 159 * Get the derived class. 160 * @return derived class 161 */ 162 public ClassNode getDerivedClass() { 163 return classNode; 164 } 165 166 /** Add the primitive field. 167 */ 168 private void addPrimitiveField() { 169 final FieldNode primitiveField = 170 new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC, 171 PRIMITIVE_FIELD, Type.getDescriptor(primitiveClass), null, null); 172 classNode.fields.add(primitiveField); 173 } 174 175 /** Add the class constructor. 176 */ 177 private void addConstructor() { 178 final MethodNode constructor = 179 new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT, 180 Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)), 181 null, null); 182 constructor.visitVarInsn(Opcodes.ALOAD, 0); 183 constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getType(Object.class).getInternalName(), 184 INIT, "()V"); 185 constructor.visitVarInsn(Opcodes.ALOAD, 0); 186 constructor.visitVarInsn(Opcodes.ALOAD, 1); 187 constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD, 188 Type.getDescriptor(primitiveClass)); 189 constructor.visitInsn(Opcodes.RETURN); 190 constructor.visitMaxs(0, 0); 191 classNode.methods.add(constructor); 192 } 193 194 /** Add the getPrimitiveField method. 195 */ 196 private void addGetPrimitiveFieldMethod() { 197 final MethodNode method = 198 new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveField", 199 Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)), 200 null, null); 201 final Label start = new Label(); 202 final Label end = new Label(); 203 method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class)); 204 method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class)); 205 method.visitLabel(start); 206 method.visitLdcInsn(Type.getType(primitiveClass)); 207 method.visitVarInsn(Opcodes.ALOAD, 1); 208 method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), 209 "getDeclaredField", 210 Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class))); 211 method.visitVarInsn(Opcodes.ASTORE, 2); 212 method.visitVarInsn(Opcodes.ALOAD, 2); 213 method.visitInsn(Opcodes.ICONST_1); 214 method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class), 215 "setAccessible", 216 Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE)); 217 method.visitVarInsn(Opcodes.ALOAD, 2); 218 method.visitVarInsn(Opcodes.ALOAD, 0); 219 method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD, 220 Type.getDescriptor(primitiveClass)); 221 method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class), 222 "get", 223 Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class))); 224 method.visitInsn(Opcodes.ARETURN); 225 method.visitLabel(end); 226 method.visitVarInsn(Opcodes.ASTORE, 2); 227 method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class)); 228 method.visitInsn(Opcodes.DUP); 229 method.visitVarInsn(Opcodes.ALOAD, 2); 230 method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class), 231 INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class))); 232 method.visitInsn(Opcodes.ATHROW); 233 classNode.methods.add(method); 234 } 235 236 }