001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.algorithmic.forward.analysis;
018    
019    import java.util.ArrayList;
020    import java.util.HashMap;
021    import java.util.HashSet;
022    import java.util.IdentityHashMap;
023    import java.util.Iterator;
024    import java.util.List;
025    import java.util.Map;
026    import java.util.Set;
027    
028    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer1;
029    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer12;
030    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer2;
031    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer1;
032    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer12;
033    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer2;
034    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer1;
035    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer12;
036    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer2;
037    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DNegTransformer;
038    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer1;
039    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer12;
040    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer2;
041    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer1;
042    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer12;
043    import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer2;
044    import org.apache.commons.nabla.algorithmic.forward.functions.AcosTransformer;
045    import org.apache.commons.nabla.algorithmic.forward.functions.AcoshTransformer;
046    import org.apache.commons.nabla.algorithmic.forward.functions.AsinTransformer;
047    import org.apache.commons.nabla.algorithmic.forward.functions.AsinhTransformer;
048    import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer1;
049    import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer12;
050    import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer2;
051    import org.apache.commons.nabla.algorithmic.forward.functions.AtanTransformer;
052    import org.apache.commons.nabla.algorithmic.forward.functions.AtanhTransformer;
053    import org.apache.commons.nabla.algorithmic.forward.functions.CbrtTransformer;
054    import org.apache.commons.nabla.algorithmic.forward.functions.CosTransformer;
055    import org.apache.commons.nabla.algorithmic.forward.functions.CoshTransformer;
056    import org.apache.commons.nabla.algorithmic.forward.functions.ExpTransformer;
057    import org.apache.commons.nabla.algorithmic.forward.functions.Expm1Transformer;
058    import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer1;
059    import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer12;
060    import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer2;
061    import org.apache.commons.nabla.algorithmic.forward.functions.Log10Transformer;
062    import org.apache.commons.nabla.algorithmic.forward.functions.Log1pTransformer;
063    import org.apache.commons.nabla.algorithmic.forward.functions.LogTransformer;
064    import org.apache.commons.nabla.algorithmic.forward.functions.MathInvocationTransformer;
065    import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer1;
066    import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer12;
067    import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer2;
068    import org.apache.commons.nabla.algorithmic.forward.functions.SinTransformer;
069    import org.apache.commons.nabla.algorithmic.forward.functions.SinhTransformer;
070    import org.apache.commons.nabla.algorithmic.forward.functions.SqrtTransformer;
071    import org.apache.commons.nabla.algorithmic.forward.functions.TanTransformer;
072    import org.apache.commons.nabla.algorithmic.forward.functions.TanhTransformer;
073    import org.apache.commons.nabla.algorithmic.forward.instructions.DLoadTransformer;
074    import org.apache.commons.nabla.algorithmic.forward.instructions.DReturnTransformer;
075    import org.apache.commons.nabla.algorithmic.forward.instructions.DStoreTransformer;
076    import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer1;
077    import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer12;
078    import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer2;
079    import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2Transformer;
080    import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X1Transformer;
081    import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer1;
082    import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer12;
083    import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer2;
084    import org.apache.commons.nabla.algorithmic.forward.instructions.NarrowingTransformer;
085    import org.apache.commons.nabla.algorithmic.forward.instructions.WideningTransformer;
086    import org.apache.commons.nabla.algorithmic.forward.trimming.DLoadPop2Trimmer;
087    import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDloadTrimmer;
088    import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDstoreTrimmer;
089    import org.apache.commons.nabla.core.DifferentialPair;
090    import org.apache.commons.nabla.core.DifferentiationException;
091    import org.objectweb.asm.MethodVisitor;
092    import org.objectweb.asm.Opcodes;
093    import org.objectweb.asm.tree.AbstractInsnNode;
094    import org.objectweb.asm.tree.FieldInsnNode;
095    import org.objectweb.asm.tree.IincInsnNode;
096    import org.objectweb.asm.tree.InsnList;
097    import org.objectweb.asm.tree.InsnNode;
098    import org.objectweb.asm.tree.LabelNode;
099    import org.objectweb.asm.tree.MethodInsnNode;
100    import org.objectweb.asm.tree.MethodNode;
101    import org.objectweb.asm.tree.VarInsnNode;
102    import org.objectweb.asm.tree.analysis.Analyzer;
103    import org.objectweb.asm.tree.analysis.AnalyzerException;
104    import org.objectweb.asm.tree.analysis.BasicValue;
105    import org.objectweb.asm.tree.analysis.Frame;
106    import org.objectweb.asm.tree.analysis.Interpreter;
107    
108    /** Class transforming a method computing a value to a method
109     * computing both a value and its differential.
110     */
111    public class MethodDifferentiator extends MethodNode {
112    
113        /** Name for the DifferentialPair class. */
114        public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
115    
116        /** Descriptor for the DifferentialPair class. */
117        public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
118    
119        /** Descriptor for the derivative class f method. */
120        public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
121    
122        /** Descriptor for <code>double f()</code> methods. */
123        private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
124    
125        /** Math functions transformer. */
126        private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS =
127            new HashMap<String, MathInvocationTransformer>();
128    
129        static {
130            MATH_TRANSFORMERS.put("acos",     new AcosTransformer());
131            MATH_TRANSFORMERS.put("acosh",    new AcoshTransformer());
132            MATH_TRANSFORMERS.put("asin",     new AsinTransformer());
133            MATH_TRANSFORMERS.put("asinh",    new AsinhTransformer());
134            MATH_TRANSFORMERS.put("atan2_12", new Atan2Transformer12());
135            MATH_TRANSFORMERS.put("atan2_1",  new Atan2Transformer1());
136            MATH_TRANSFORMERS.put("atan2_2",  new Atan2Transformer2());
137            MATH_TRANSFORMERS.put("atan",     new AtanTransformer());
138            MATH_TRANSFORMERS.put("atanh",    new AtanhTransformer());
139            MATH_TRANSFORMERS.put("cbrt",     new CbrtTransformer());
140            MATH_TRANSFORMERS.put("cos",      new CosTransformer());
141            MATH_TRANSFORMERS.put("cosh",     new CoshTransformer());
142            MATH_TRANSFORMERS.put("exp",      new ExpTransformer());
143            MATH_TRANSFORMERS.put("expm1",    new Expm1Transformer());
144            MATH_TRANSFORMERS.put("hypot_12", new HypotTransformer12());
145            MATH_TRANSFORMERS.put("hypot_1",  new HypotTransformer1());
146            MATH_TRANSFORMERS.put("hypot_2",  new HypotTransformer2());
147            MATH_TRANSFORMERS.put("log10",    new Log10Transformer());
148            MATH_TRANSFORMERS.put("log1p",    new Log1pTransformer());
149            MATH_TRANSFORMERS.put("log",      new LogTransformer());
150            MATH_TRANSFORMERS.put("pow_12",   new PowTransformer12());
151            MATH_TRANSFORMERS.put("pow_1",    new PowTransformer1());
152            MATH_TRANSFORMERS.put("pow_2",    new PowTransformer2());
153            MATH_TRANSFORMERS.put("sin",      new SinTransformer());
154            MATH_TRANSFORMERS.put("sinh",     new SinhTransformer());
155            MATH_TRANSFORMERS.put("sqrt",     new SqrtTransformer());
156            MATH_TRANSFORMERS.put("tan",      new TanTransformer());
157            MATH_TRANSFORMERS.put("tanh",     new TanhTransformer());
158        }
159    
160        /** Message format for unknown method. */
161        private static final String UNKNOWN_METHOD_FMT = "unknown method {0}.{1}";
162    
163        /** Maximal number of temporary size 2 variables. */
164        private static final int MAX_TEMP = 5;
165    
166        /** Math implementation classes. */
167        private final Set<String> mathClasses;
168    
169        /** Generator to use. */
170        private final MethodVisitor generator;
171    
172        /** Used locals variables array. */
173        private boolean[] usedLocals;
174    
175        /** Primitive class name. */
176        private final String primitiveName;
177    
178        /** Error reporter to use. */
179        private final ErrorReporter errorReporter;
180    
181        /** Set of converted values. */
182        private final Set<TrackingValue> converted;
183    
184        /** Frames for the original method. */
185        private final Map<AbstractInsnNode, Frame> frames;
186    
187        /** Instructions successors array. */
188        private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors;
189    
190        /** Cloned labels map. */
191        private final Map<LabelNode, LabelNode> clonedLabels;
192    
193        /** Build a differentiator for a method.
194         * @param access access flags of the method
195         * @param name name of the method
196         * @param desc descriptor of the method
197         * @param signature signature of the method
198         * @param exceptions exceptions thrown by the method
199         * @param generator bytecode generator to use for the transformed method
200         * @param primitiveName primitive class name
201         * @param mathClasses math implementation classes
202         * @param errorReporter reporter used for delaying exceptions
203         */
204        public MethodDifferentiator(final int access, final String name, final String desc,
205                                    final String signature, final String[] exceptions,
206                                    final MethodVisitor generator,final  String primitiveName,
207                                    final Set<String> mathClasses,
208                                    final ErrorReporter errorReporter) {
209    
210            super(access, name, desc, signature, exceptions);
211            this.generator     = generator;
212            this.usedLocals    = null;
213            this.primitiveName = primitiveName;
214            this.mathClasses   = mathClasses;
215            this.errorReporter = errorReporter;
216            this.converted     = new HashSet<TrackingValue>();
217            this.frames        = new IdentityHashMap<AbstractInsnNode, Frame>();
218            this.successors    = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
219            this.clonedLabels  = new HashMap<LabelNode, LabelNode>();
220    
221        }
222    
223        /** {@inheritDoc} */
224        @Override
225        public void visitEnd() {
226            try {
227    
228                // at start, "this" and one differential pair are used
229                maxLocals  = 2 * (maxLocals + MAX_TEMP) - 1;
230                usedLocals = new boolean[maxLocals];
231                useLocal(0, 1);
232                useLocal(1, 4);
233    
234                // add spare cells to hold new variables if needed
235                addSpareLocalVariables();
236    
237                // analyze the original code, tracing values production/consumption
238                final Frame[] array =
239                    new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this);
240    
241                // convert the array into a map, since code changes will shift all indices
242                for (int i = 0; i < array.length; ++i) {
243                    frames.put(instructions.get(i), array[i]);
244                }
245    
246                // identify the needed changes
247                final Set<AbstractInsnNode> changes = identifyChanges();
248    
249                if (changes.isEmpty()) {
250    
251                    // the method does not depend on the parameter at all!
252                    // we replace all code by a simple "return DifferentialPair.ZERO;"
253                    instructions.clear();
254                    instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR));
255                    instructions.add(new InsnNode(Opcodes.ARETURN));
256    
257                } else {
258    
259                    // perform the code changes
260                    changeCode(changes);
261    
262                    // remove the local variables added at the beginning and not used
263                    removeUnusedSpareLocalVariables();
264    
265                    // trim generated instructions list
266                    SwappedDloadTrimmer.getInstance().trim(instructions);
267                    SwappedDstoreTrimmer.getInstance().trim(instructions);
268                    DLoadPop2Trimmer.getInstance().trim(instructions);
269    
270                }
271    
272                // change the descriptor to its true final value
273                desc = DP_RETURN_DP_DESCRIPTOR;
274    
275                // generate the method
276                accept(generator);
277    
278            } catch (AnalyzerException ae) {
279                if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
280                    errorReporter.register((DifferentiationException) ae.getCause());
281                } else {
282                    final DifferentiationException de =
283                        new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
284                                                new Object[] {
285                                                    primitiveName, name, ae.getMessage()
286                                                });
287                    errorReporter.register(de);
288                }
289            } catch (DifferentiationException de) {
290                errorReporter.register(de);
291            }
292        }
293    
294        /** Add spare cells for new local variables.
295         * <p>In order to ease conversion from double values to differential pairs,
296         * we start by reserving one spare cell between each original local variables.
297         * So we have to modify the indices in all instructions referencing local
298         * variables in the original code, to take into account the renumbering
299         * introduced by these spare cells. The spare cells by themselves will
300         * be referenced by the converted instructions in the following passes.</p>
301         * <p>The spare cells that will not be used will be reclaimed after
302         * conversion, to avoid wasting memory.</p>
303         * @exception DifferentiationException if local variables array has not been
304         * expanded appropriately beforehand
305         * @see #removeUnusedSpareLocalVariables()
306         */
307        private void addSpareLocalVariables() throws DifferentiationException {
308            for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
309                final AbstractInsnNode insn = (AbstractInsnNode) i.next();
310                if (insn.getType() == AbstractInsnNode.VAR_INSN) {
311                    final VarInsnNode varInsn = (VarInsnNode) insn;
312                    if (varInsn.var > 2) {
313                        varInsn.var = 2 * varInsn.var - 1;
314                        final int opcode = varInsn.getOpcode();
315                        if ((opcode == Opcodes.ILOAD)  || (opcode == Opcodes.FLOAD)  ||
316                            (opcode == Opcodes.ALOAD)  || (opcode == Opcodes.ISTORE) ||
317                            (opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) {
318                            useLocal(varInsn.var, 1);
319                        } else {
320                            useLocal(varInsn.var, 2);
321                        }
322                    }
323                } else if (insn.getOpcode() == Opcodes.IINC) {
324                    final IincInsnNode iincInsn = (IincInsnNode) insn;
325                    if (iincInsn.var > 2) {
326                        iincInsn.var = 2 * iincInsn.var - 1;
327                        useLocal(iincInsn.var, 1);
328                    }
329                }
330            }
331        }
332    
333        /** Remove the unused spare cells introduced at conversion start.
334         * @see #addSpareLocalVariables()
335         */
336        private void removeUnusedSpareLocalVariables() {
337            for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
338                final AbstractInsnNode insn = (AbstractInsnNode) i.next();
339                if (insn.getType() == AbstractInsnNode.VAR_INSN) {
340                    shiftVariable((VarInsnNode) insn);
341                }
342            }
343        }
344    
345        /** Identify the instructions that must be changed.
346         * <p>Identification is based on data flow analysis. We start by changing
347         * the local variables in the initial frame to match the parameters of
348         * the derivative method, and propagate these variables following the
349         * instructions path, updating stack cells and local variables as needed.
350         * Instructions that must be changed are the ones that consume changed
351         * variables or stack cells.</p>
352         * @return set containing all the instructions that must be changed
353         */
354        private Set<AbstractInsnNode> identifyChanges() {
355    
356            // the pending set contains the values (local variables or stack cells)
357            // that have been changed, they will trigger changes on the instructions
358            // that consume them
359            final Set<TrackingValue> pending = new HashSet<TrackingValue>();
360    
361            // the changes set contains the instructions that must be changed
362            final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
363    
364            // start by converting the parameter of the method,
365            // which is kept in local variable 1 of the initial frame
366            final TrackingValue dpParameter = (TrackingValue) frames.get(instructions.get(0)).getLocal(1);
367            pending.add(dpParameter);
368    
369            // propagate the values conversions throughout the method
370            while (!pending.isEmpty()) {
371    
372                // pop one element from the set of changed values
373                final Iterator<TrackingValue> iterator = pending.iterator();
374                final TrackingValue value = iterator.next();
375                iterator.remove();
376    
377                // this value is converted
378                converted.add(value);
379    
380                // check the consumers instructions for this value
381                for (final AbstractInsnNode consumer : value.getConsumers()) {
382    
383                    // an instruction consuming a converted value and producing
384                    // a double must be changed to produce a differential pair,
385                    // get the double values produced and add them to the changed set
386                    for (TrackingValue produced : getProducedDoubleValues(consumer)) {
387    
388                        // add it to the pending set if it has not already been processed
389                        if (!converted.contains(produced)) {
390                            pending.add(produced);
391                        }
392    
393                    }
394    
395                    // as a consumer of a converted value, the instruction must be changed
396                    changes.add(consumer);
397    
398                }
399    
400                // check the producers instructions for this value
401                for (final AbstractInsnNode producer : value.getProducers()) {
402    
403                    // an instruction producing a converted value must be changed
404                    changes.add(producer);
405    
406                }
407            }
408    
409            return changes;
410    
411        }
412    
413        /** Get the list of double values produced by an instruction.
414         * @param instruction instruction producing the values
415         * @return list of double values produced
416         */
417        private List<TrackingValue> getProducedDoubleValues(final AbstractInsnNode instruction) {
418    
419            final List<TrackingValue> values = new ArrayList<TrackingValue>();
420    
421            // get the frame before instruction execution
422            final Frame before = frames.get(instruction);
423            final int beforeStackSize = before.getStackSize();
424            final int locals = before.getLocals();
425    
426            // check the frames produced by this instruction
427            // (they correspond to the input frames of its successors)
428            final Set<AbstractInsnNode> set = successors.get(instruction);
429            if (set != null) {
430    
431                // loop over the successors of this instruction
432                for (final AbstractInsnNode successor : set) {
433                    final Frame produced = frames.get(successor);
434    
435                    // check the stack cells
436                    for (int i = 0; i < produced.getStackSize(); ++i) {
437                        final TrackingValue value = (TrackingValue) produced.getStack(i);
438                        if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
439                            value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
440                            values.add(value);
441                        }
442                    }
443    
444                    // check the local variables
445                    for (int i = 0; i < locals; ++i) {
446                        final TrackingValue value = (TrackingValue) produced.getLocal(i);
447                        if ((value != before.getLocal(i)) &&
448                            value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
449                            values.add(value);
450                        }
451                    }
452                }
453            }
454    
455            return values;
456    
457        }
458    
459        /** Perform the code changes.
460         * @param changes instructions that must be changed
461         * @exception DifferentiationException if some instruction cannot be handled
462         */
463        private void changeCode(final Set<AbstractInsnNode> changes)
464            throws DifferentiationException {
465    
466            // insert the parameter conversion code at method start
467            final InsnList list = new InsnList();
468            list.add(new VarInsnNode(Opcodes.ALOAD, 1));
469            list.add(new InsnNode(Opcodes.DUP));
470            list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
471                                        "getValue", VOID_RETURN_D_DESCRIPTOR));
472            list.add(new VarInsnNode(Opcodes.DSTORE, 1));
473            list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
474                                        "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR));
475            list.add(new VarInsnNode(Opcodes.DSTORE, 3));
476    
477            instructions.insertBefore(instructions.get(0), list);
478    
479            // transform the existing instructions
480            for (final AbstractInsnNode insn : changes) {
481                instructions.insert(insn, getReplacement(insn));
482                instructions.remove(insn);
483            }
484    
485        }
486    
487        /** Get the replacement list for an instruction.
488         * @param insn instruction to replace
489         * @return replacement instructions list
490         * @exception DifferentiationException if some instruction cannot be handled
491         * or if no temporary variable can be reserved
492         */
493        private InsnList getReplacement(final AbstractInsnNode insn)
494            throws DifferentiationException {
495    
496            // get the frame at the start of the instruction
497            final Frame frame = frames.get(insn);
498            final int size = frame.getStackSize();
499            final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2));
500            final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1));
501    
502            switch(insn.getOpcode()) {
503            case Opcodes.DLOAD :
504                useLocal(((VarInsnNode) insn).var, 4);
505                return  DLoadTransformer.getInstance().getReplacement(insn, this);
506            case Opcodes.DALOAD :
507                // TODO add support for DALOAD differentiation
508                throw new RuntimeException("DALOAD not handled yet");
509            case Opcodes.DSTORE :
510                useLocal(((VarInsnNode) insn).var, 4);
511                return  DStoreTransformer.getInstance().getReplacement(insn, this);
512            case Opcodes.DASTORE :
513                // TODO add support for DASTORE differentiation
514                throw new RuntimeException("DASTORE not handled yet");
515            case Opcodes.DUP2 :
516                return Dup2Transformer.getInstance().getReplacement(insn, this);
517            case Opcodes.DUP2_X1 :
518                return Dup2X1Transformer.getInstance().getReplacement(insn, this);
519            case Opcodes.DUP2_X2 :
520                if (stack1Converted) {
521                    if (stack0Converted) {
522                        return Dup2X2Transformer12.getInstance().getReplacement(insn, this);
523                    }
524                    return Dup2X2Transformer1.getInstance().getReplacement(insn, this);
525                }
526                return Dup2X2Transformer2.getInstance().getReplacement(insn, this);
527            case Opcodes.DADD :
528                if (stack1Converted) {
529                    if (stack0Converted) {
530                        return DAddTransformer12.getInstance().getReplacement(insn, this);
531                    }
532                    return DAddTransformer1.getInstance().getReplacement(insn, this);
533                }
534                return  DAddTransformer2.getInstance().getReplacement(insn, this);
535            case Opcodes.DSUB :
536                if (stack1Converted) {
537                    if (stack0Converted) {
538                        return DSubTransformer12.getInstance().getReplacement(insn, this);
539                    }
540                    return DSubTransformer1.getInstance().getReplacement(insn, this);
541                }
542                return  DSubTransformer2.getInstance().getReplacement(insn, this);
543            case Opcodes.DMUL :
544                if (stack1Converted) {
545                    if (stack0Converted) {
546                        return  DMulTransformer12.getInstance().getReplacement(insn, this);
547                    }
548                    return  DMulTransformer1.getInstance().getReplacement(insn, this);
549                }
550                return  DMulTransformer2.getInstance().getReplacement(insn, this);
551            case Opcodes.DDIV :
552                if (stack1Converted) {
553                    if (stack0Converted) {
554                        return  DDivTransformer12.getInstance().getReplacement(insn, this);
555                    }
556                    return  DDivTransformer1.getInstance().getReplacement(insn, this);
557                }
558                return  DDivTransformer2.getInstance().getReplacement(insn, this);
559            case Opcodes.DREM :
560                if (stack1Converted) {
561                    if (stack0Converted) {
562                        return  DRemTransformer12.getInstance().getReplacement(insn, this);
563                    }
564                    return  DRemTransformer1.getInstance().getReplacement(insn, this);
565                }
566                return  DRemTransformer2.getInstance().getReplacement(insn, this);
567            case Opcodes.DNEG :
568                return  DNegTransformer.getInstance().getReplacement(insn, this);
569            case Opcodes.DCONST_0 :
570            case Opcodes.DCONST_1 :
571            case Opcodes.LDC :
572            case Opcodes.I2D :
573            case Opcodes.L2D :
574            case Opcodes.F2D :
575                return WideningTransformer.getInstance().getReplacement(insn, this);
576            case Opcodes.POP2 :
577            case Opcodes.D2I :
578            case Opcodes.D2L :
579            case Opcodes.D2F :
580                return NarrowingTransformer.getInstance().getReplacement(insn, this);
581            case Opcodes.DCMPL :
582            case Opcodes.DCMPG :
583                if (stack1Converted) {
584                    if (stack0Converted) {
585                        return  DcmpTransformer12.getInstance().getReplacement(insn, this);
586                    }
587                    return  DcmpTransformer1.getInstance().getReplacement(insn, this);
588                }
589                return  DcmpTransformer2.getInstance().getReplacement(insn, this);
590            case Opcodes.DRETURN :
591                return  DReturnTransformer.getInstance().getReplacement(insn, this);
592            case Opcodes.GETSTATIC :
593                // TODO add support for GETSTATIC differentiation
594                throw new RuntimeException("GETSTATIC not handled yet");
595            case Opcodes.PUTSTATIC :
596                // TODO add support for PUTSTATIC differentiation
597                throw new RuntimeException("PUTSTATIC not handled yet");
598            case Opcodes.GETFIELD :
599                // TODO add support for GETFIELD differentiation
600                throw new RuntimeException("GETFIELD not handled yet");
601            case Opcodes.PUTFIELD :
602                // TODO add support for PUTFIELD differentiation
603                throw new RuntimeException("PUTFIELD not handled yet");
604            case Opcodes.INVOKEVIRTUAL :
605                // TODO add support for INVOKEVIRTUAL differentiation
606                throw new RuntimeException("INVOKEVIRTUAL not handled yet");
607            case Opcodes.INVOKESPECIAL :
608                // TODO add support for INVOKESPECIAL differentiation
609                throw new RuntimeException("INVOKESPECIAL not handled yet");
610            case Opcodes.INVOKESTATIC :
611                return replaceInvocation((MethodInsnNode) insn,
612                                         stack1Converted, stack0Converted);
613            case Opcodes.INVOKEINTERFACE :
614                // TODO add support for INVOKEINTERFACE differentiation
615                throw new RuntimeException("INVOKEINTERFACE not handled yet");
616            case Opcodes.NEWARRAY :
617                // TODO add support for NEWARRAY differentiation
618                throw new RuntimeException("NEWARRAY not handled yet");
619            case Opcodes.ANEWARRAY :
620                // TODO add support for ANEWARRAY differentiation
621                throw new RuntimeException("ANEWARRAY not handled yet");
622            case Opcodes.MULTIANEWARRAY :
623                // TODO add support for MULTIANEWARRAY differentiation
624                throw new RuntimeException("MULTIANEWARRAY not handled yet");
625            default:
626                throw new DifferentiationException("unable to handle instruction with opcode {0}",
627                                              new Object[] {
628                                                  Integer.valueOf(insn.getOpcode())
629                                              });
630            }
631    
632        }
633    
634        /** Replace an INVOKESTATIC instruction.
635         * @param methodInsn invocation instruction
636         * @param stack1Converted if true, the stack sub-head has been
637         * converted to differential pair
638         * @param stack0Converted if true, the stack head has been
639         * converted to differential pair
640         * @return replacement instructions list
641         * @exception DifferentiationException if the instruction cannot be replaced
642         */
643        private InsnList replaceInvocation(final MethodInsnNode methodInsn,
644                                           final boolean stack1Converted,
645                                           final boolean stack0Converted)
646            throws DifferentiationException {
647            if (isMathImplementationClass(methodInsn.owner)) {
648                if ("(D)D".equals(methodInsn.desc)) {
649                    // this is a univariate method like sin, cos, exp ...
650                    final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name);
651                    if (transformer == null) {
652                        throw new DifferentiationException(UNKNOWN_METHOD_FMT,
653                                                      methodInsn.owner, methodInsn.name);
654                    }
655                    return transformer.getReplacementList(methodInsn.owner, this);
656                } else if ("(DD)D".equals(methodInsn.desc)) {
657                    // this is a bivariate method like atan2, pow ...
658    
659                    // we may want to differentiate against first, second or both parameters
660                    String name = null;
661                    if (stack1Converted) {
662                        if (stack0Converted) {
663                            name = methodInsn.name + "_12";
664                        } else {
665                            name = methodInsn.name + "_1";
666                        }
667                    } else if (stack0Converted) {
668                        name = methodInsn.name + "_2";
669                    }
670    
671                    if (name != null) {
672                        final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name);
673                        if (transformer == null) {
674                            throw new DifferentiationException(UNKNOWN_METHOD_FMT,
675                                                          methodInsn.owner, methodInsn.name);
676                        }
677                        return transformer.getReplacementList(methodInsn.owner, this);
678                    }
679                }
680            }
681            throw new DifferentiationException("unexpected instruction {0}",
682                                               Integer.valueOf(methodInsn.getOpcode()));
683        }
684    
685        /** Test if a class is a math implementation class.
686         * @param name name of the class to test
687         * @return true if the named class is a math implementation class
688         */
689        public boolean isMathImplementationClass(final String name) {
690            return mathClasses.contains(name);
691        }
692    
693        /** Set a local variable as used by the modified code.
694         * @param index index of the variable
695         * @param size size of the variable (1 or 2 for standard variables,
696         * 4 for special expanded differential pairs)
697         * @exception DifferentiationException if the number of the
698         * temporary variable lies outside of the allowed range
699         */
700        public void useLocal(final int index, final int size)
701            throws DifferentiationException {
702            if ((index < 0) || ((index + size - 1) >= usedLocals.length)) {
703                throw new DifferentiationException("index of size {0} local variable ({1}) " +
704                                              "outside of [{2}, {3}] range",
705                                              Integer.valueOf(size), Integer.valueOf(index),
706                                              Integer.valueOf(1), Integer.valueOf(MAX_TEMP));
707            }
708            for (int i = index; i < index + size; ++i) {
709                usedLocals[i] = true;
710            }
711        }
712    
713        /** Get index of a size 2 temporary variable.
714         * <p>Temporary variables can be used for very short duration
715         * storage of size 2 values (i.e one double, or one long or two
716         * integers). These variables are reused in many replacement
717         * instructions sequences, so their content may be overridden
718         * at any time: they should be considered to have a scope
719         * limited to one replacement sequence only. This means that
720         * one should <em>not</em> store a value in a variable in one
721         * replacement sequence and retrieve it later in another
722         * replacement sequence as it may have been overridden in
723         * between.</p>
724         * <p>At most 5 temporary variables may be used.</p>
725         * @param number number of the temporary variable (must be
726         * between 1 and the maximal number of available temporary
727         * variables)
728         * @return index of the variable within the local variables
729         * array
730         * @exception DifferentiationException if the number of the
731         * temporary variable lies outside of the allowed range
732         */
733        public int getTmp(final int number) throws DifferentiationException {
734            if ((number < 0) || (number > MAX_TEMP)) {
735                throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range",
736                                                   Integer.valueOf(number),
737                                                   Integer.valueOf(1),
738                                                   Integer.valueOf(MAX_TEMP));
739            }
740            final int index = usedLocals.length - 2 * number;
741            useLocal(index, 2);
742            return index;
743        }
744    
745        /** Shifted the index of a variable instruction.
746         * @param insn variable instruction
747         */
748        public void shiftVariable(final VarInsnNode insn) {
749            int shifted = 0;
750            for (int i = 0; i < insn.var; ++i) {
751                if (usedLocals[i]) {
752                    ++shifted;
753                }
754            }
755            insn.var = shifted;
756        }
757    
758        /** Clone an instruction.
759         * @param insn instruction to clone
760         * @return cloned instruction
761         */
762        public AbstractInsnNode clone(final AbstractInsnNode insn) {
763            return insn.clone(clonedLabels);
764        }
765    
766        /** Analyzer preserving instructions successors information. */
767        private class FlowAnalyzer extends Analyzer {
768    
769            /** Simple constructor.
770             * @param interpreter associated interpreter
771             */
772            public FlowAnalyzer(final Interpreter interpreter) {
773                super(interpreter);
774            }
775    
776            /** Store a new edge.
777             * @param insn current instruction
778             * @param successor successor instruction
779             */
780            protected void newControlFlowEdge(final int insn, final int successor) {
781                // store the successor information
782                final AbstractInsnNode node = instructions.get(insn);
783                Set<AbstractInsnNode> set = successors.get(node);
784                if (set == null) {
785                    set = new HashSet<AbstractInsnNode>();
786                    successors.put(node, set);
787                }
788                set.add(instructions.get(successor));
789            }
790    
791        }
792    
793    }