001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.algorithmic.forward.analysis;
018    
019    import java.util.Set;
020    
021    import org.apache.commons.nabla.core.DifferentiationException;
022    import org.apache.commons.nabla.core.UnivariateDerivative;
023    import org.apache.commons.nabla.core.UnivariateDifferentiable;
024    import org.objectweb.asm.AnnotationVisitor;
025    import org.objectweb.asm.Attribute;
026    import org.objectweb.asm.ClassVisitor;
027    import org.objectweb.asm.FieldVisitor;
028    import org.objectweb.asm.MethodVisitor;
029    import org.objectweb.asm.Opcodes;
030    
031    /**
032     * Visitor (in asm sense) for differentiating classes.
033     * <p>
034     * This visitor visits classes implementing the
035     * {@link UnivariateDifferentiable UnivariateDifferentiable} interface and convert
036     * them to classes implementing the {@link UnivariateDerivative
037     * UnivariateDerivative} interface.
038     * </p>
039     * <p>
040     * The visitor creates a new class as an inner class of the visited class.
041     * Instances of the generated class are therefore automatically bound to their
042     * primitive instance which is their directly enclosing instance. As such they
043     * have access to the current value of all fields.
044     * </p>
045     * <p>
046     * The visited class bytecode is not changed at all.
047     * </p>
048     */
049    public class ClassDifferentiator implements ClassVisitor {
050    
051        /** Name for the primitive instance field. */
052        private static final String PRIMITIVE_FIELD = "primitive";
053    
054        /** Math implementation classes. */
055        private final Set<String> mathClasses;
056    
057        /** Class generating visitor. */
058        private final ClassVisitor generator;
059    
060        /** Error reporter. */
061        private final ErrorReporter errorReporter;
062    
063        /** Primitive class name. */
064        private String primitiveName;
065    
066        /** Descriptor for the primitive class. */
067        private String primitiveDesc;
068    
069        /** Derivative class name. */
070        private String derivativeName;
071    
072        /** Indicator for specific fields and method addition. */
073        private boolean specificMembersAdded;
074    
075        /**
076         * Simple constructor.
077         * @param mathClasses math implementation classes
078         * @param generator visitor to which class generation calls will be delegated
079         */
080        public ClassDifferentiator(final Set<String> mathClasses,
081                              final ClassVisitor generator) {
082            this.mathClasses = mathClasses;
083            this.generator   = generator;
084            errorReporter    = new ErrorReporter();
085        }
086    
087        /**
088         * Get the name of the derivative class.
089         * @return name of the (generated) derivative class
090         */
091        public String getDerivativeClassName() {
092            return derivativeName;
093        }
094    
095        /** {@inheritDoc} */
096        public void visit(final int version, final int access,
097                          final String name, final String signature,
098                          final String superName, final String[] interfaces) {
099            // set up the various names
100            primitiveName = name;
101            derivativeName   = primitiveName + "$NablaUnivariateDerivative";
102            primitiveDesc = "L" + primitiveName + ";";
103    
104            // check the UnivariateDifferentiable interface is implemented
105            final Class<UnivariateDifferentiable> uDerClass = UnivariateDifferentiable.class;
106            boolean isDifferentiable = false;
107            for (String interf : interfaces) {
108                final String interfName = interf.replace('/', '.');
109                Class<?> interfClass = null;
110                try {
111                    interfClass = Class.forName(interfName);
112                } catch (ClassNotFoundException cnfe) {
113                    // this should never occur since class has already been loaded
114                    // and an instance already exists ...
115                    errorReporter.register(new DifferentiationException("interface {0} not found " +
116                                                                        "while differentiating class {1}",
117                                                                        interfName, name));
118                }
119                if (interfClass != null) {
120                    isDifferentiable = isDifferentiable || uDerClass.isAssignableFrom(interfClass);
121                }
122            }
123    
124            if (isDifferentiable) {
125                // generate the new class implementing the UnivariateDerivative interface
126                generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
127                                derivativeName, signature, superName,
128                                new String[] {
129                                    UnivariateDerivative.class.getName().replace('.', '/')
130                                });
131            } else {
132                errorReporter.register(new DifferentiationException("the {0} class does not implement " +
133                                                                    "the {1} interface",
134                                                                    name, uDerClass.getName()));
135            }
136    
137            specificMembersAdded = false;
138    
139        }
140    
141        /** {@inheritDoc} */
142        public MethodVisitor visitMethod(final int access, final String name,
143                                         final String desc, final String signature,
144                                         final String[] exceptions) {
145    
146            // don't do anything if an error has already been encountered
147            if (errorReporter.hasError()) {
148                return null;
149            }
150    
151            if (!specificMembersAdded) {
152                // add the specific members we need
153                addPrimitiveField();
154                addConstructor();
155                addGetPrimitive();
156                specificMembersAdded = true;
157            }
158    
159            // is it the "public double f(double)" method we want to differentiate ?
160            if (((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) &&
161                    "f".equals(name) && "(D)D".equals(desc) &&
162                    ((exceptions == null) || (exceptions.length == 0))) {
163    
164                // get a generator for the method we are going to create
165                final MethodVisitor visitor =
166                    generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name,
167                                          MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null);
168    
169                // make sure our own differentiator will be used to transform the code
170                return new MethodDifferentiator(access, name, desc, signature, exceptions,
171                                           visitor, primitiveName, mathClasses, errorReporter);
172    
173            }
174    
175            // we are not interested in this method
176            return null;
177    
178        }
179    
180        /** {@inheritDoc} */
181        public FieldVisitor visitField(final int access, final String name,
182                                       final String desc, final  String signature,
183                                       final Object value) {
184            // we are not interested in any fields
185            return null;
186        }
187    
188        /** {@inheritDoc} */
189        public void visitSource(final String source, final String debug) {
190        }
191    
192        /** {@inheritDoc} */
193        public void visitOuterClass(final String owner, final String name,
194                                    final String desc) {
195        }
196    
197        /** {@inheritDoc} */
198        public AnnotationVisitor visitAnnotation(final String desc,
199                                                 final boolean visible) {
200            return null;
201        }
202    
203        /** {@inheritDoc} */
204        public void visitAttribute(final Attribute attr) {
205        }
206    
207        /** {@inheritDoc} */
208        public void visitInnerClass(final String name, final String outerName,
209                                    final String innerName, final int access) {
210        }
211    
212        /** {@inheritDoc} */
213        public void visitEnd() {
214    
215            // don't do anything if an error has already been encountered
216            if (errorReporter.hasError()) {
217                return;
218            }
219    
220            generator.visitEnd();
221    
222        }
223    
224        /** Add the primitive field.
225         */
226        private void addPrimitiveField() {
227            final FieldVisitor visitor =
228                generator.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
229                                     PRIMITIVE_FIELD, primitiveDesc, null, null);
230            visitor.visitEnd();
231        }
232    
233        /** Add the class constructor.
234         */
235        private void addConstructor() {
236            final String init = "<init>";
237            final MethodVisitor visitor =
238                generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init,
239                                      "(" + primitiveDesc + ")V", null, null);
240            visitor.visitCode();
241            visitor.visitVarInsn(Opcodes.ALOAD, 0);
242            visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V");
243            visitor.visitVarInsn(Opcodes.ALOAD, 0);
244            visitor.visitVarInsn(Opcodes.ALOAD, 1);
245            visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
246            visitor.visitInsn(Opcodes.RETURN);
247            visitor.visitMaxs(0, 0);
248            visitor.visitEnd();
249        }
250    
251        /** Add the {@link UnivariateDerivative#getPrimitive() getPrimitive()} method.
252         */
253        private void addGetPrimitive() {
254            final MethodVisitor visitor =
255                generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
256                                      "()" + primitiveDesc, null, null);
257            visitor.visitCode();
258            visitor.visitVarInsn(Opcodes.ALOAD, 0);
259            visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
260            visitor.visitInsn(Opcodes.ARETURN);
261            visitor.visitMaxs(0, 0);
262            visitor.visitEnd();
263        }
264    
265        /** Report the errors that may have occurred during analysis.
266         * @exception DifferentiationException if the derivative class
267         * could not be generated
268         */
269        public void reportErrors() throws DifferentiationException {
270            errorReporter.reportErrors();
271        }
272    
273    }