/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.nabla.algorithmic.forward.analysis;

import java.util.Set;

import org.apache.commons.nabla.core.DifferentiationException;
import org.apache.commons.nabla.core.UnivariateDerivative;
import org.apache.commons.nabla.core.UnivariateDifferentiable;
import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.Attribute;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

/**
 * Visitor (in asm sense) for differentiating classes.
 * <p>
 * This visitor visits classes implementing the
 * {@link UnivariateDifferentiable UnivariateDifferentiable} interface and converts
 * them to classes implementing the {@link UnivariateDerivative
 * UnivariateDerivative} interface.
 * </p>
 * <p>
 * The generated class is named after the visited (primitive) class, using the
 * nested class naming convention
 * (<code>PrimitiveName$NablaUnivariateDerivative</code>). Each instance of the
 * generated class holds a reference to its primitive instance, which is
 * provided at construction time and can be retrieved through the generated
 * <code>getPrimitive()</code> method.
 * </p>
 * <p>
 * The visited class bytecode is not changed at all: all generation calls are
 * forwarded to the generator visitor provided at construction time.
 * </p>
 */
public class ClassDifferentiator implements ClassVisitor {

    /** Name for the primitive instance field. */
    private static final String PRIMITIVE_FIELD = "primitive";

    /** Suffix appended to the primitive class name to build the derivative class name. */
    private static final String DERIVATIVE_SUFFIX = "$NablaUnivariateDerivative";

    /** Internal name of the superclass of the generated class. */
    private static final String OBJECT_INTERNAL_NAME = "java/lang/Object";

    /** Name of constructors at bytecode level. */
    private static final String INIT = "<init>";

    /** Math implementation classes. */
    private final Set<String> mathClasses;

    /** Class generating visitor. */
    private final ClassVisitor generator;

    /** Error reporter. */
    private final ErrorReporter errorReporter;

    /** Primitive class name. */
    private String primitiveName;

    /** Descriptor for the primitive class. */
    private String primitiveDesc;

    /** Derivative class name. */
    private String derivativeName;

    /** Indicator for specific fields and method addition. */
    private boolean specificMembersAdded;

    /**
     * Simple constructor.
     * @param mathClasses math implementation classes
     * @param generator visitor to which class generation calls will be delegated
     */
    public ClassDifferentiator(final Set<String> mathClasses,
                               final ClassVisitor generator) {
        this.mathClasses   = mathClasses;
        this.generator     = generator;
        this.errorReporter = new ErrorReporter();
    }

    /**
     * Get the name of the derivative class.
     * <p>
     * The name is only available once the {@link #visit(int, int, String,
     * String, String, String[]) visit} method has been called.
     * </p>
     * @return name of the (generated) derivative class, or null if no
     * class has been visited yet
     */
    public String getDerivativeClassName() {
        return derivativeName;
    }

    /**
     * Start the visit of the primitive class.
     * <p>
     * If the primitive class implements {@link UnivariateDifferentiable},
     * generation of the derivative class is started, otherwise an error
     * is registered and will be reported by {@link #reportErrors()}.
     * </p>
     * @param version class file version of the primitive class
     * (reused for the generated class)
     * @param access access flags of the primitive class (unused)
     * @param name internal name of the primitive class
     * @param signature generic signature of the primitive class (unused)
     * @param superName internal name of the primitive superclass (unused)
     * @param interfaces internal names of the interfaces directly implemented
     * by the primitive class, may be null
     */
    public void visit(final int version, final int access,
                      final String name, final String signature,
                      final String superName, final String[] interfaces) {

        // set up the various names
        primitiveName  = name;
        derivativeName = primitiveName + DERIVATIVE_SUFFIX;
        primitiveDesc  = "L" + primitiveName + ";";

        if (implementsDifferentiable(name, interfaces)) {
            // generate the new class implementing the UnivariateDerivative interface;
            // the generated class must extend Object since its constructor chains
            // to Object.<init>, and it must not inherit the primitive generic
            // signature, which describes the primitive superclass and interfaces
            generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
                            derivativeName, null, OBJECT_INTERNAL_NAME,
                            new String[] {
                                UnivariateDerivative.class.getName().replace('.', '/')
                            });
        } else {
            errorReporter.register(new DifferentiationException("the {0} class does not implement " +
                                                                "the {1} interface",
                                                                name,
                                                                UnivariateDifferentiable.class.getName()));
        }

        specificMembersAdded = false;

    }

    /**
     * Check if one of the specified interfaces is (or extends) the
     * {@link UnivariateDifferentiable} interface.
     * <p>
     * An error is registered for each interface that cannot be loaded.
     * </p>
     * @param name internal name of the class declaring the interfaces
     * (used for error messages only)
     * @param interfaces internal names of the interfaces to check, may be null
     * @return true if at least one interface is assignable to
     * {@link UnivariateDifferentiable}
     */
    private boolean implementsDifferentiable(final String name, final String[] interfaces) {

        // asm is allowed to provide a null array for classes without interfaces
        if (interfaces == null) {
            return false;
        }

        boolean found = false;
        for (final String internalName : interfaces) {
            final String binaryName = internalName.replace('/', '.');
            try {
                final Class<?> candidate = Class.forName(binaryName);
                if (UnivariateDifferentiable.class.isAssignableFrom(candidate)) {
                    // don't stop at first match, the remaining interfaces
                    // must still be checked for loading errors
                    found = true;
                }
            } catch (ClassNotFoundException cnfe) {
                // this should never occur since class has already been loaded
                // and an instance already exists ...
                errorReporter.register(new DifferentiationException("interface {0} not found " +
                                                                    "while differentiating class {1}",
                                                                    binaryName, name));
            }
        }

        return found;

    }

    /**
     * Visit a method of the primitive class.
     * <p>
     * Only the <code>public double f(double)</code> method declaring no
     * exceptions is differentiated, all other methods are ignored. The
     * specific members of the generated class (primitive field, constructor
     * and <code>getPrimitive</code> method) are added the first time this
     * method is called without any pending error.
     * </p>
     * @param access access flags of the method
     * @param name name of the method
     * @param desc descriptor of the method
     * @param signature generic signature of the method
     * @param exceptions internal names of the exceptions declared by the
     * method, may be null
     * @return a visitor differentiating the method code, or null if the
     * method is ignored or if an error has already been encountered
     */
    public MethodVisitor visitMethod(final int access, final String name,
                                     final String desc, final String signature,
                                     final String[] exceptions) {

        // don't do anything if an error has already been encountered
        if (errorReporter.hasError()) {
            return null;
        }

        if (!specificMembersAdded) {
            // add the specific members we need
            addPrimitiveField();
            addConstructor();
            addGetPrimitive();
            specificMembersAdded = true;
        }

        // is it the "public double f(double)" method we want to differentiate ?
        final boolean isPublic     = (access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC;
        final boolean noExceptions = (exceptions == null) || (exceptions.length == 0);
        if (!(isPublic && "f".equals(name) && "(D)D".equals(desc) && noExceptions)) {
            // we are not interested in this method
            return null;
        }

        // get a generator for the method we are going to create
        final MethodVisitor visitor =
            generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name,
                                  MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null);

        // make sure our own differentiator will be used to transform the code
        return new MethodDifferentiator(access, name, desc, signature, exceptions,
                                        visitor, primitiveName, mathClasses, errorReporter);

    }

    /**
     * Visit a field of the primitive class.
     * @param access access flags of the field
     * @param name name of the field
     * @param desc descriptor of the field
     * @param signature generic signature of the field
     * @param value initial value of the field
     * @return always null, since fields are ignored
     */
    public FieldVisitor visitField(final int access, final String name,
                                   final String desc, final String signature,
                                   final Object value) {
        // we are not interested in any fields
        return null;
    }

    /** {@inheritDoc} */
    public void visitSource(final String source, final String debug) {
        // nothing to do, source information is ignored
    }

    /** {@inheritDoc} */
    public void visitOuterClass(final String owner, final String name,
                                final String desc) {
        // nothing to do, enclosing class information is ignored
    }

    /**
     * Visit an annotation of the primitive class.
     * @param desc descriptor of the annotation class
     * @param visible true if the annotation is visible at runtime
     * @return always null, since annotations are ignored
     */
    public AnnotationVisitor visitAnnotation(final String desc,
                                             final boolean visible) {
        return null;
    }

    /** {@inheritDoc} */
    public void visitAttribute(final Attribute attr) {
        // nothing to do, non standard attributes are ignored
    }

    /** {@inheritDoc} */
    public void visitInnerClass(final String name, final String outerName,
                                final String innerName, final int access) {
        // nothing to do, inner classes information is ignored
    }

    /**
     * End the visit of the primitive class.
     * <p>
     * The generation of the derivative class is completed only if no
     * error has been encountered.
     * </p>
     */
    public void visitEnd() {

        // don't do anything if an error has already been encountered
        if (errorReporter.hasError()) {
            return;
        }

        generator.visitEnd();

    }

    /** Add the primitive field.
     */
    private void addPrimitiveField() {
        final int access = Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC;
        final FieldVisitor visitor =
            generator.visitField(access, PRIMITIVE_FIELD, primitiveDesc, null, null);
        visitor.visitEnd();
    }

    /** Add the class constructor.
     * <p>
     * The generated constructor takes the primitive instance as its single
     * argument, chains to the <code>Object</code> constructor and stores
     * the primitive instance in the primitive field.
     * </p>
     */
    private void addConstructor() {
        final MethodVisitor visitor =
            generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
                                  "(" + primitiveDesc + ")V", null, null);
        visitor.visitCode();

        // super();
        visitor.visitVarInsn(Opcodes.ALOAD, 0);
        visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, OBJECT_INTERNAL_NAME, INIT, "()V");

        // this.primitive = primitive;
        visitor.visitVarInsn(Opcodes.ALOAD, 0);
        visitor.visitVarInsn(Opcodes.ALOAD, 1);
        visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);

        visitor.visitInsn(Opcodes.RETURN);

        // NOTE(review): zero sizes presumably rely on the generator computing
        // the maximal stack and locals sizes by itself -- confirm against caller
        visitor.visitMaxs(0, 0);
        visitor.visitEnd();
    }

    /** Add the {@link UnivariateDerivative#getPrimitive() getPrimitive()} method.
     * <p>
     * The generated method simply returns the content of the primitive field.
     * </p>
     */
    private void addGetPrimitive() {
        final MethodVisitor visitor =
            generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
                                  "()" + primitiveDesc, null, null);
        visitor.visitCode();

        // return this.primitive;
        visitor.visitVarInsn(Opcodes.ALOAD, 0);
        visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
        visitor.visitInsn(Opcodes.ARETURN);

        // NOTE(review): zero sizes presumably rely on the generator computing
        // the maximal stack and locals sizes by itself -- confirm against caller
        visitor.visitMaxs(0, 0);
        visitor.visitEnd();
    }

    /** Report the errors that may have occurred during analysis.
     * @exception DifferentiationException if the derivative class
     * could not be generated
     */
    public void reportErrors() throws DifferentiationException {
        errorReporter.reportErrors();
    }

}