1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla.automatic.analysis;
18
19 import java.util.ArrayList;
20 import java.util.HashMap;
21 import java.util.HashSet;
22 import java.util.IdentityHashMap;
23 import java.util.Iterator;
24 import java.util.List;
25 import java.util.Map;
26 import java.util.Set;
27
28 import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer1;
29 import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer12;
30 import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer2;
31 import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer1;
32 import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer12;
33 import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer2;
34 import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer1;
35 import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer12;
36 import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer2;
37 import org.apache.commons.nabla.automatic.arithmetic.DNegTransformer;
38 import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer1;
39 import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer12;
40 import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer2;
41 import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer1;
42 import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer12;
43 import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer2;
44 import org.apache.commons.nabla.automatic.functions.AcosTransformer;
45 import org.apache.commons.nabla.automatic.functions.AcoshTransformer;
46 import org.apache.commons.nabla.automatic.functions.AsinTransformer;
47 import org.apache.commons.nabla.automatic.functions.AsinhTransformer;
48 import org.apache.commons.nabla.automatic.functions.Atan2Transformer1;
49 import org.apache.commons.nabla.automatic.functions.Atan2Transformer12;
50 import org.apache.commons.nabla.automatic.functions.Atan2Transformer2;
51 import org.apache.commons.nabla.automatic.functions.AtanTransformer;
52 import org.apache.commons.nabla.automatic.functions.AtanhTransformer;
53 import org.apache.commons.nabla.automatic.functions.CbrtTransformer;
54 import org.apache.commons.nabla.automatic.functions.CosTransformer;
55 import org.apache.commons.nabla.automatic.functions.CoshTransformer;
56 import org.apache.commons.nabla.automatic.functions.ExpTransformer;
57 import org.apache.commons.nabla.automatic.functions.Expm1Transformer;
58 import org.apache.commons.nabla.automatic.functions.HypotTransformer1;
59 import org.apache.commons.nabla.automatic.functions.HypotTransformer12;
60 import org.apache.commons.nabla.automatic.functions.HypotTransformer2;
61 import org.apache.commons.nabla.automatic.functions.Log10Transformer;
62 import org.apache.commons.nabla.automatic.functions.Log1pTransformer;
63 import org.apache.commons.nabla.automatic.functions.LogTransformer;
64 import org.apache.commons.nabla.automatic.functions.MathInvocationTransformer;
65 import org.apache.commons.nabla.automatic.functions.PowTransformer1;
66 import org.apache.commons.nabla.automatic.functions.PowTransformer12;
67 import org.apache.commons.nabla.automatic.functions.PowTransformer2;
68 import org.apache.commons.nabla.automatic.functions.SinTransformer;
69 import org.apache.commons.nabla.automatic.functions.SinhTransformer;
70 import org.apache.commons.nabla.automatic.functions.SqrtTransformer;
71 import org.apache.commons.nabla.automatic.functions.TanTransformer;
72 import org.apache.commons.nabla.automatic.functions.TanhTransformer;
73 import org.apache.commons.nabla.automatic.instructions.DLoadTransformer;
74 import org.apache.commons.nabla.automatic.instructions.DReturnTransformer;
75 import org.apache.commons.nabla.automatic.instructions.DStoreTransformer;
76 import org.apache.commons.nabla.automatic.instructions.DcmpTransformer1;
77 import org.apache.commons.nabla.automatic.instructions.DcmpTransformer12;
78 import org.apache.commons.nabla.automatic.instructions.DcmpTransformer2;
79 import org.apache.commons.nabla.automatic.instructions.Dup2Transformer;
80 import org.apache.commons.nabla.automatic.instructions.Dup2X1Transformer;
81 import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer1;
82 import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer12;
83 import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer2;
84 import org.apache.commons.nabla.automatic.instructions.NarrowingTransformer;
85 import org.apache.commons.nabla.automatic.instructions.WideningTransformer;
86 import org.apache.commons.nabla.automatic.trimming.DLoadPop2Trimmer;
87 import org.apache.commons.nabla.automatic.trimming.SwappedDloadTrimmer;
88 import org.apache.commons.nabla.automatic.trimming.SwappedDstoreTrimmer;
89 import org.apache.commons.nabla.core.DifferentialPair;
90 import org.apache.commons.nabla.core.DifferentiationException;
91 import org.objectweb.asm.MethodVisitor;
92 import org.objectweb.asm.Opcodes;
93 import org.objectweb.asm.tree.AbstractInsnNode;
94 import org.objectweb.asm.tree.FieldInsnNode;
95 import org.objectweb.asm.tree.IincInsnNode;
96 import org.objectweb.asm.tree.InsnList;
97 import org.objectweb.asm.tree.InsnNode;
98 import org.objectweb.asm.tree.LabelNode;
99 import org.objectweb.asm.tree.MethodInsnNode;
100 import org.objectweb.asm.tree.MethodNode;
101 import org.objectweb.asm.tree.VarInsnNode;
102 import org.objectweb.asm.tree.analysis.Analyzer;
103 import org.objectweb.asm.tree.analysis.AnalyzerException;
104 import org.objectweb.asm.tree.analysis.BasicValue;
105 import org.objectweb.asm.tree.analysis.Frame;
106 import org.objectweb.asm.tree.analysis.Interpreter;
107
108
109
110
111 public class MethodDifferentiator extends MethodNode {
112
113
114 public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
115
116
117 public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
118
119
120 public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
121
122
123 private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
124
125
126 private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS =
127 new HashMap<String, MathInvocationTransformer>();
128
129 static {
130 MATH_TRANSFORMERS.put("acos", new AcosTransformer());
131 MATH_TRANSFORMERS.put("acosh", new AcoshTransformer());
132 MATH_TRANSFORMERS.put("asin", new AsinTransformer());
133 MATH_TRANSFORMERS.put("asinh", new AsinhTransformer());
134 MATH_TRANSFORMERS.put("atan2_12", new Atan2Transformer12());
135 MATH_TRANSFORMERS.put("atan2_1", new Atan2Transformer1());
136 MATH_TRANSFORMERS.put("atan2_2", new Atan2Transformer2());
137 MATH_TRANSFORMERS.put("atan", new AtanTransformer());
138 MATH_TRANSFORMERS.put("atanh", new AtanhTransformer());
139 MATH_TRANSFORMERS.put("cbrt", new CbrtTransformer());
140 MATH_TRANSFORMERS.put("cos", new CosTransformer());
141 MATH_TRANSFORMERS.put("cosh", new CoshTransformer());
142 MATH_TRANSFORMERS.put("exp", new ExpTransformer());
143 MATH_TRANSFORMERS.put("expm1", new Expm1Transformer());
144 MATH_TRANSFORMERS.put("hypot_12", new HypotTransformer12());
145 MATH_TRANSFORMERS.put("hypot_1", new HypotTransformer1());
146 MATH_TRANSFORMERS.put("hypot_2", new HypotTransformer2());
147 MATH_TRANSFORMERS.put("log10", new Log10Transformer());
148 MATH_TRANSFORMERS.put("log1p", new Log1pTransformer());
149 MATH_TRANSFORMERS.put("log", new LogTransformer());
150 MATH_TRANSFORMERS.put("pow_12", new PowTransformer12());
151 MATH_TRANSFORMERS.put("pow_1", new PowTransformer1());
152 MATH_TRANSFORMERS.put("pow_2", new PowTransformer2());
153 MATH_TRANSFORMERS.put("sin", new SinTransformer());
154 MATH_TRANSFORMERS.put("sinh", new SinhTransformer());
155 MATH_TRANSFORMERS.put("sqrt", new SqrtTransformer());
156 MATH_TRANSFORMERS.put("tan", new TanTransformer());
157 MATH_TRANSFORMERS.put("tanh", new TanhTransformer());
158 }
159
160
161 private static final String UNKNOWN_METHOD_FMT = "unknown method {0}.{1}";
162
163
164 private static final int MAX_TEMP = 5;
165
166
167 private final Set<String> mathClasses;
168
169
170 private final MethodVisitor generator;
171
172
173 private boolean[] usedLocals;
174
175
176 private final String primitiveName;
177
178
179 private final ErrorReporter errorReporter;
180
181
182 private final Set<TrackingValue> converted;
183
184
185 private final Map<AbstractInsnNode, Frame> frames;
186
187
188 private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors;
189
190
191 private final Map<LabelNode, LabelNode> clonedLabels;
192
193
194
195
196
197
198
199
200
201
202
203
204 public MethodDifferentiator(final int access, final String name, final String desc,
205 final String signature, final String[] exceptions,
206 final MethodVisitor generator,final String primitiveName,
207 final Set<String> mathClasses,
208 final ErrorReporter errorReporter) {
209
210 super(access, name, desc, signature, exceptions);
211 this.generator = generator;
212 this.usedLocals = null;
213 this.primitiveName = primitiveName;
214 this.mathClasses = mathClasses;
215 this.errorReporter = errorReporter;
216 this.converted = new HashSet<TrackingValue>();
217 this.frames = new IdentityHashMap<AbstractInsnNode, Frame>();
218 this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
219 this.clonedLabels = new HashMap<LabelNode, LabelNode>();
220
221 }
222
223
224 @Override
225 public void visitEnd() {
226 try {
227
228
229 maxLocals = 2 * (maxLocals + MAX_TEMP) - 1;
230 usedLocals = new boolean[maxLocals];
231 useLocal(0, 1);
232 useLocal(1, 4);
233
234
235 addSpareLocalVariables();
236
237
238 final Frame[] array =
239 new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this);
240
241
242 for (int i = 0; i < array.length; ++i) {
243 frames.put(instructions.get(i), array[i]);
244 }
245
246
247 final Set<AbstractInsnNode> changes = identifyChanges();
248
249 if (changes.isEmpty()) {
250
251
252
253 instructions.clear();
254 instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR));
255 instructions.add(new InsnNode(Opcodes.ARETURN));
256
257 } else {
258
259
260 changeCode(changes);
261
262
263 removeUnusedSpareLocalVariables();
264
265
266 SwappedDloadTrimmer.getInstance().trim(instructions);
267 SwappedDstoreTrimmer.getInstance().trim(instructions);
268 DLoadPop2Trimmer.getInstance().trim(instructions);
269
270 }
271
272
273 desc = DP_RETURN_DP_DESCRIPTOR;
274
275
276 accept(generator);
277
278 } catch (AnalyzerException ae) {
279 if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
280 errorReporter.register((DifferentiationException) ae.getCause());
281 } else {
282 final DifferentiationException de =
283 new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
284 new Object[] {
285 primitiveName, name, ae.getMessage()
286 });
287 errorReporter.register(de);
288 }
289 } catch (DifferentiationException de) {
290 errorReporter.register(de);
291 }
292 }
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307 private void addSpareLocalVariables() throws DifferentiationException {
308 for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
309 final AbstractInsnNode insn = (AbstractInsnNode) i.next();
310 if (insn.getType() == AbstractInsnNode.VAR_INSN) {
311 final VarInsnNode varInsn = (VarInsnNode) insn;
312 if (varInsn.var > 2) {
313 varInsn.var = 2 * varInsn.var - 1;
314 final int opcode = varInsn.getOpcode();
315 if ((opcode == Opcodes.ILOAD) || (opcode == Opcodes.FLOAD) ||
316 (opcode == Opcodes.ALOAD) || (opcode == Opcodes.ISTORE) ||
317 (opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) {
318 useLocal(varInsn.var, 1);
319 } else {
320 useLocal(varInsn.var, 2);
321 }
322 }
323 } else if (insn.getOpcode() == Opcodes.IINC) {
324 final IincInsnNode iincInsn = (IincInsnNode) insn;
325 if (iincInsn.var > 2) {
326 iincInsn.var = 2 * iincInsn.var - 1;
327 useLocal(iincInsn.var, 1);
328 }
329 }
330 }
331 }
332
333
334
335
336 private void removeUnusedSpareLocalVariables() {
337 for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
338 final AbstractInsnNode insn = (AbstractInsnNode) i.next();
339 if (insn.getType() == AbstractInsnNode.VAR_INSN) {
340 shiftVariable((VarInsnNode) insn);
341 }
342 }
343 }
344
345
346
347
348
349
350
351
352
353
354 private Set<AbstractInsnNode> identifyChanges() {
355
356
357
358
359 final Set<TrackingValue> pending = new HashSet<TrackingValue>();
360
361
362 final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
363
364
365
366 final TrackingValue dpParameter = (TrackingValue) frames.get(instructions.get(0)).getLocal(1);
367 pending.add(dpParameter);
368
369
370 while (!pending.isEmpty()) {
371
372
373 final Iterator<TrackingValue> iterator = pending.iterator();
374 final TrackingValue value = iterator.next();
375 iterator.remove();
376
377
378 converted.add(value);
379
380
381 for (final AbstractInsnNode consumer : value.getConsumers()) {
382
383
384
385
386 for (TrackingValue produced : getProducedDoubleValues(consumer)) {
387
388
389 if (!converted.contains(produced)) {
390 pending.add(produced);
391 }
392
393 }
394
395
396 changes.add(consumer);
397
398 }
399
400
401 for (final AbstractInsnNode producer : value.getProducers()) {
402
403
404 changes.add(producer);
405
406 }
407 }
408
409 return changes;
410
411 }
412
413
414
415
416
417 private List<TrackingValue> getProducedDoubleValues(final AbstractInsnNode instruction) {
418
419 final List<TrackingValue> values = new ArrayList<TrackingValue>();
420
421
422 final Frame before = frames.get(instruction);
423 final int beforeStackSize = before.getStackSize();
424 final int locals = before.getLocals();
425
426
427
428 final Set<AbstractInsnNode> set = successors.get(instruction);
429 if (set != null) {
430
431
432 for (final AbstractInsnNode successor : set) {
433 final Frame produced = frames.get(successor);
434
435
436 for (int i = 0; i < produced.getStackSize(); ++i) {
437 final TrackingValue value = (TrackingValue) produced.getStack(i);
438 if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
439 value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
440 values.add(value);
441 }
442 }
443
444
445 for (int i = 0; i < locals; ++i) {
446 final TrackingValue value = (TrackingValue) produced.getLocal(i);
447 if ((value != before.getLocal(i)) &&
448 value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
449 values.add(value);
450 }
451 }
452 }
453 }
454
455 return values;
456
457 }
458
459
460
461
462
463 private void changeCode(final Set<AbstractInsnNode> changes)
464 throws DifferentiationException {
465
466
467 final InsnList list = new InsnList();
468 list.add(new VarInsnNode(Opcodes.ALOAD, 1));
469 list.add(new InsnNode(Opcodes.DUP));
470 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
471 "getValue", VOID_RETURN_D_DESCRIPTOR));
472 list.add(new VarInsnNode(Opcodes.DSTORE, 1));
473 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
474 "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR));
475 list.add(new VarInsnNode(Opcodes.DSTORE, 3));
476
477 instructions.insertBefore(instructions.get(0), list);
478
479
480 for (final AbstractInsnNode insn : changes) {
481 instructions.insert(insn, getReplacement(insn));
482 instructions.remove(insn);
483 }
484
485 }
486
487
488
489
490
491
492
493 private InsnList getReplacement(final AbstractInsnNode insn)
494 throws DifferentiationException {
495
496
497 final Frame frame = frames.get(insn);
498 final int size = frame.getStackSize();
499 final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2));
500 final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1));
501
502 switch(insn.getOpcode()) {
503 case Opcodes.DLOAD :
504 useLocal(((VarInsnNode) insn).var, 4);
505 return DLoadTransformer.getInstance().getReplacement(insn, this);
506 case Opcodes.DALOAD :
507
508 throw new RuntimeException("DALOAD not handled yet");
509 case Opcodes.DSTORE :
510 useLocal(((VarInsnNode) insn).var, 4);
511 return DStoreTransformer.getInstance().getReplacement(insn, this);
512 case Opcodes.DASTORE :
513
514 throw new RuntimeException("DASTORE not handled yet");
515 case Opcodes.DUP2 :
516 return Dup2Transformer.getInstance().getReplacement(insn, this);
517 case Opcodes.DUP2_X1 :
518 return Dup2X1Transformer.getInstance().getReplacement(insn, this);
519 case Opcodes.DUP2_X2 :
520 if (stack1Converted) {
521 if (stack0Converted) {
522 return Dup2X2Transformer12.getInstance().getReplacement(insn, this);
523 }
524 return Dup2X2Transformer1.getInstance().getReplacement(insn, this);
525 }
526 return Dup2X2Transformer2.getInstance().getReplacement(insn, this);
527 case Opcodes.DADD :
528 if (stack1Converted) {
529 if (stack0Converted) {
530 return DAddTransformer12.getInstance().getReplacement(insn, this);
531 }
532 return DAddTransformer1.getInstance().getReplacement(insn, this);
533 }
534 return DAddTransformer2.getInstance().getReplacement(insn, this);
535 case Opcodes.DSUB :
536 if (stack1Converted) {
537 if (stack0Converted) {
538 return DSubTransformer12.getInstance().getReplacement(insn, this);
539 }
540 return DSubTransformer1.getInstance().getReplacement(insn, this);
541 }
542 return DSubTransformer2.getInstance().getReplacement(insn, this);
543 case Opcodes.DMUL :
544 if (stack1Converted) {
545 if (stack0Converted) {
546 return DMulTransformer12.getInstance().getReplacement(insn, this);
547 }
548 return DMulTransformer1.getInstance().getReplacement(insn, this);
549 }
550 return DMulTransformer2.getInstance().getReplacement(insn, this);
551 case Opcodes.DDIV :
552 if (stack1Converted) {
553 if (stack0Converted) {
554 return DDivTransformer12.getInstance().getReplacement(insn, this);
555 }
556 return DDivTransformer1.getInstance().getReplacement(insn, this);
557 }
558 return DDivTransformer2.getInstance().getReplacement(insn, this);
559 case Opcodes.DREM :
560 if (stack1Converted) {
561 if (stack0Converted) {
562 return DRemTransformer12.getInstance().getReplacement(insn, this);
563 }
564 return DRemTransformer1.getInstance().getReplacement(insn, this);
565 }
566 return DRemTransformer2.getInstance().getReplacement(insn, this);
567 case Opcodes.DNEG :
568 return DNegTransformer.getInstance().getReplacement(insn, this);
569 case Opcodes.DCONST_0 :
570 case Opcodes.DCONST_1 :
571 case Opcodes.LDC :
572 case Opcodes.I2D :
573 case Opcodes.L2D :
574 case Opcodes.F2D :
575 return WideningTransformer.getInstance().getReplacement(insn, this);
576 case Opcodes.POP2 :
577 case Opcodes.D2I :
578 case Opcodes.D2L :
579 case Opcodes.D2F :
580 return NarrowingTransformer.getInstance().getReplacement(insn, this);
581 case Opcodes.DCMPL :
582 case Opcodes.DCMPG :
583 if (stack1Converted) {
584 if (stack0Converted) {
585 return DcmpTransformer12.getInstance().getReplacement(insn, this);
586 }
587 return DcmpTransformer1.getInstance().getReplacement(insn, this);
588 }
589 return DcmpTransformer2.getInstance().getReplacement(insn, this);
590 case Opcodes.DRETURN :
591 return DReturnTransformer.getInstance().getReplacement(insn, this);
592 case Opcodes.GETSTATIC :
593
594 throw new RuntimeException("GETSTATIC not handled yet");
595 case Opcodes.PUTSTATIC :
596
597 throw new RuntimeException("PUTSTATIC not handled yet");
598 case Opcodes.GETFIELD :
599
600 throw new RuntimeException("GETFIELD not handled yet");
601 case Opcodes.PUTFIELD :
602
603 throw new RuntimeException("PUTFIELD not handled yet");
604 case Opcodes.INVOKEVIRTUAL :
605
606 throw new RuntimeException("INVOKEVIRTUAL not handled yet");
607 case Opcodes.INVOKESPECIAL :
608
609 throw new RuntimeException("INVOKESPECIAL not handled yet");
610 case Opcodes.INVOKESTATIC :
611 return replaceInvocation((MethodInsnNode) insn,
612 stack1Converted, stack0Converted);
613 case Opcodes.INVOKEINTERFACE :
614
615 throw new RuntimeException("INVOKEINTERFACE not handled yet");
616 case Opcodes.NEWARRAY :
617
618 throw new RuntimeException("NEWARRAY not handled yet");
619 case Opcodes.ANEWARRAY :
620
621 throw new RuntimeException("ANEWARRAY not handled yet");
622 case Opcodes.MULTIANEWARRAY :
623
624 throw new RuntimeException("MULTIANEWARRAY not handled yet");
625 default:
626 throw new DifferentiationException("unable to handle instruction with opcode {0}",
627 new Object[] {
628 Integer.valueOf(insn.getOpcode())
629 });
630 }
631
632 }
633
634
635
636
637
638
639
640
641
642
643 private InsnList replaceInvocation(final MethodInsnNode methodInsn,
644 final boolean stack1Converted,
645 final boolean stack0Converted)
646 throws DifferentiationException {
647 if (isMathImplementationClass(methodInsn.owner)) {
648 if ("(D)D".equals(methodInsn.desc)) {
649
650 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name);
651 if (transformer == null) {
652 throw new DifferentiationException(UNKNOWN_METHOD_FMT,
653 methodInsn.owner, methodInsn.name);
654 }
655 return transformer.getReplacementList(methodInsn.owner, this);
656 } else if ("(DD)D".equals(methodInsn.desc)) {
657
658
659
660 String name = null;
661 if (stack1Converted) {
662 if (stack0Converted) {
663 name = methodInsn.name + "_12";
664 } else {
665 name = methodInsn.name + "_1";
666 }
667 } else if (stack0Converted) {
668 name = methodInsn.name + "_2";
669 }
670
671 if (name != null) {
672 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name);
673 if (transformer == null) {
674 throw new DifferentiationException(UNKNOWN_METHOD_FMT,
675 methodInsn.owner, methodInsn.name);
676 }
677 return transformer.getReplacementList(methodInsn.owner, this);
678 }
679 }
680 }
681 throw new DifferentiationException("unexpected instruction {0}",
682 Integer.valueOf(methodInsn.getOpcode()));
683 }
684
685
686
687
688
689 public boolean isMathImplementationClass(final String name) {
690 return mathClasses.contains(name);
691 }
692
693
694
695
696
697
698
699
700 public void useLocal(final int index, final int size)
701 throws DifferentiationException {
702 if ((index < 0) || ((index + size - 1) >= usedLocals.length)) {
703 throw new DifferentiationException("index of size {0} local variable ({1}) " +
704 "outside of [{2}, {3}] range",
705 Integer.valueOf(size), Integer.valueOf(index),
706 Integer.valueOf(1), Integer.valueOf(MAX_TEMP));
707 }
708 for (int i = index; i < index + size; ++i) {
709 usedLocals[i] = true;
710 }
711 }
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733 public int getTmp(final int number) throws DifferentiationException {
734 if ((number < 0) || (number > MAX_TEMP)) {
735 throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range",
736 Integer.valueOf(number),
737 Integer.valueOf(1),
738 Integer.valueOf(MAX_TEMP));
739 }
740 final int index = usedLocals.length - 2 * number;
741 useLocal(index, 2);
742 return index;
743 }
744
745
746
747
748 public void shiftVariable(final VarInsnNode insn) {
749 int shifted = 0;
750 for (int i = 0; i < insn.var; ++i) {
751 if (usedLocals[i]) {
752 ++shifted;
753 }
754 }
755 insn.var = shifted;
756 }
757
758
759
760
761
762 public AbstractInsnNode clone(final AbstractInsnNode insn) {
763 return insn.clone(clonedLabels);
764 }
765
766
767 private class FlowAnalyzer extends Analyzer {
768
769
770
771
772 public FlowAnalyzer(final Interpreter interpreter) {
773 super(interpreter);
774 }
775
776
777
778
779
780 protected void newControlFlowEdge(final int insn, final int successor) {
781
782 final AbstractInsnNode node = instructions.get(insn);
783 Set<AbstractInsnNode> set = successors.get(node);
784 if (set == null) {
785 set = new HashSet<AbstractInsnNode>();
786 successors.put(node, set);
787 }
788 set.add(instructions.get(successor));
789 }
790
791 }
792
793 }