1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla.forward.analysis;
18
19 import java.util.ArrayList;
20 import java.util.HashSet;
21 import java.util.IdentityHashMap;
22 import java.util.Iterator;
23 import java.util.List;
24 import java.util.ListIterator;
25 import java.util.Map;
26 import java.util.Set;
27
28 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
29 import org.apache.commons.nabla.DifferentiationException;
30 import org.apache.commons.nabla.NablaMessages;
31 import org.apache.commons.nabla.forward.arithmetic.DAddTransformer;
32 import org.apache.commons.nabla.forward.arithmetic.DDivTransformer;
33 import org.apache.commons.nabla.forward.arithmetic.DMulTransformer;
34 import org.apache.commons.nabla.forward.arithmetic.DNegTransformer;
35 import org.apache.commons.nabla.forward.arithmetic.DRemTransformer;
36 import org.apache.commons.nabla.forward.arithmetic.DSubTransformer;
37 import org.apache.commons.nabla.forward.instructions.DLoadTransformer;
38 import org.apache.commons.nabla.forward.instructions.DReturnTransformer;
39 import org.apache.commons.nabla.forward.instructions.DStoreTransformer;
40 import org.apache.commons.nabla.forward.instructions.DcmpTransformer;
41 import org.apache.commons.nabla.forward.instructions.Dup2Transformer;
42 import org.apache.commons.nabla.forward.instructions.Dup2X1Transformer;
43 import org.apache.commons.nabla.forward.instructions.Dup2X2Transformer;
44 import org.apache.commons.nabla.forward.instructions.InvokeStaticTransformer;
45 import org.apache.commons.nabla.forward.instructions.NarrowingTransformer;
46 import org.apache.commons.nabla.forward.instructions.Pop2Transformer;
47 import org.apache.commons.nabla.forward.instructions.WideningTransformer;
48 import org.apache.commons.nabla.forward.trimming.DLoadPop2Trimmer;
49 import org.apache.commons.nabla.forward.trimming.SwappedDloadTrimmer;
50 import org.apache.commons.nabla.forward.trimming.SwappedDstoreTrimmer;
51 import org.objectweb.asm.Opcodes;
52 import org.objectweb.asm.Type;
53 import org.objectweb.asm.tree.AbstractInsnNode;
54 import org.objectweb.asm.tree.FieldInsnNode;
55 import org.objectweb.asm.tree.IincInsnNode;
56 import org.objectweb.asm.tree.InsnList;
57 import org.objectweb.asm.tree.InsnNode;
58 import org.objectweb.asm.tree.LdcInsnNode;
59 import org.objectweb.asm.tree.MethodInsnNode;
60 import org.objectweb.asm.tree.MethodNode;
61 import org.objectweb.asm.tree.TypeInsnNode;
62 import org.objectweb.asm.tree.VarInsnNode;
63 import org.objectweb.asm.tree.analysis.Analyzer;
64 import org.objectweb.asm.tree.analysis.AnalyzerException;
65 import org.objectweb.asm.tree.analysis.Frame;
66 import org.objectweb.asm.tree.analysis.Interpreter;
67
68
69
70
71
72 public class MethodDifferentiator {
73
74
75 private static final int MAX_TEMP = 5;
76
77
78 private final Set<String> mathClasses;
79
80
81 private final String derivedName;
82
83
84 private boolean[] usedLocals;
85
86
87 private final Set<TrackingValue> converted;
88
89
90 private final Map<AbstractInsnNode, Frame<TrackingValue>> frames;
91
92
93 private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors;
94
95
96
97
98
99 public MethodDifferentiator(final Set<String> mathClasses, final String derivedName) {
100 this.usedLocals = null;
101 this.mathClasses = mathClasses;
102 this.derivedName = derivedName;
103 this.converted = new HashSet<TrackingValue>();
104 this.frames = new IdentityHashMap<AbstractInsnNode, Frame<TrackingValue>>();
105 this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
106 }
107
108
109
110
111 public int getInputDSIndex() {
112
113
114 return 1;
115 }
116
117
118
119
120
121
122
123 public void differentiate(final String primitiveName, final MethodNode method)
124 throws DifferentiationException {
125 try {
126
127
128 method.maxLocals = 2 * (method.maxLocals + MAX_TEMP) - 1;
129 usedLocals = new boolean[method.maxLocals];
130 useLocal(0, 1);
131 useLocal(1, 4);
132
133 final Type dsType = Type.getType(DerivativeStructure.class);
134
135
136 addSpareLocalVariables(method.instructions);
137
138
139 final FlowAnalyzer analyzer =
140 new FlowAnalyzer(new TrackingInterpreter(), method.instructions);
141 final Frame<TrackingValue>[] array = analyzer.analyze(primitiveName, method);
142
143
144 for (int i = 0; i < array.length; ++i) {
145 frames.put(method.instructions.get(i), array[i]);
146 }
147
148
149 final Set<AbstractInsnNode> changes = identifyChanges(method.instructions);
150
151 if (changes.isEmpty()) {
152
153
154
155
156
157
158
159
160
161 for (final Iterator<AbstractInsnNode> i = method.instructions.iterator(); i.hasNext();) {
162 final AbstractInsnNode insn = i.next();
163 if (insn.getOpcode() == Opcodes.DRETURN) {
164 final InsnList list = new DReturnTransformer(false).getReplacement(insn, this);
165 method.instructions.insert(insn, list);
166 method.instructions.remove(insn);
167 }
168 }
169
170 } else {
171
172
173 for (final AbstractInsnNode insn : changes) {
174 method.instructions.insert(insn, getReplacement(insn));
175 method.instructions.remove(insn);
176 }
177
178
179 new SwappedDloadTrimmer().trim(method.instructions);
180 new SwappedDstoreTrimmer().trim(method.instructions);
181 new DLoadPop2Trimmer().trim(method.instructions);
182
183 }
184
185
186 removeUnusedSpareLocalVariables(method.instructions);
187
188
189 method.desc = Type.getMethodDescriptor(dsType, dsType);
190 method.access |= Opcodes.ACC_SYNTHETIC;
191 method.maxLocals = maxVariables();
192
193 } catch (AnalyzerException ae) {
194 ae.printStackTrace(System.err);
195 if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
196 throw (DifferentiationException) ae.getCause();
197 } else {
198 throw new DifferentiationException(NablaMessages.UNABLE_TO_ANALYZE_METHOD,
199 primitiveName, method.name, ae.getMessage());
200 }
201 }
202 }
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218 private void addSpareLocalVariables(final InsnList instructions)
219 throws DifferentiationException {
220 for (final Iterator<AbstractInsnNode> i = instructions.iterator(); i.hasNext();) {
221 final AbstractInsnNode insn = i.next();
222 if (insn.getType() == AbstractInsnNode.VAR_INSN) {
223 final VarInsnNode varInsn = (VarInsnNode) insn;
224 if (varInsn.var > 2) {
225 varInsn.var = 2 * varInsn.var - 1;
226 final int opcode = varInsn.getOpcode();
227 if ((opcode == Opcodes.ILOAD) || (opcode == Opcodes.FLOAD) ||
228 (opcode == Opcodes.ALOAD) || (opcode == Opcodes.ISTORE) ||
229 (opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) {
230 useLocal(varInsn.var, 1);
231 } else {
232 useLocal(varInsn.var, 2);
233 }
234 }
235 } else if (insn.getOpcode() == Opcodes.IINC) {
236 final IincInsnNode iincInsn = (IincInsnNode) insn;
237 if (iincInsn.var > 2) {
238 iincInsn.var = 2 * iincInsn.var - 1;
239 useLocal(iincInsn.var, 1);
240 }
241 }
242 }
243 }
244
245
246
247
248
249 private void removeUnusedSpareLocalVariables(final InsnList instructions) {
250 for (final Iterator<AbstractInsnNode> i = instructions.iterator(); i.hasNext();) {
251 final AbstractInsnNode insn = i.next();
252 if (insn.getType() == AbstractInsnNode.VAR_INSN) {
253 shiftVariable((VarInsnNode) insn);
254 }
255 }
256 }
257
258
259
260
261
262
263
264
265
266
267
268 private Set<AbstractInsnNode> identifyChanges(final InsnList instructions) {
269
270
271
272
273 final Set<TrackingValue> pending = new HashSet<TrackingValue>();
274
275
276 final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
277
278
279
280 final TrackingValue dpParameter = frames.get(instructions.get(0)).getLocal(1);
281 pending.add(dpParameter);
282
283
284 while (!pending.isEmpty()) {
285
286
287 final Iterator<TrackingValue> iterator = pending.iterator();
288 final TrackingValue value = iterator.next();
289 iterator.remove();
290
291
292 converted.add(value);
293
294
295 for (final AbstractInsnNode consumer : value.getConsumers()) {
296
297
298
299
300 pending.addAll(getProducedAndNotConvertedDoubleValues(consumer));
301
302
303 changes.add(consumer);
304
305 }
306
307
308 for (final AbstractInsnNode producer : value.getProducers()) {
309
310
311 changes.add(producer);
312
313 }
314 }
315
316
317
318 final ListIterator<AbstractInsnNode> iterator = instructions.iterator();
319 while (iterator.hasNext()) {
320 final AbstractInsnNode ins = iterator.next();
321 if ((ins.getOpcode() == Opcodes.GETFIELD) || (ins.getOpcode() == Opcodes.PUTFIELD)) {
322 changes.add(ins);
323 }
324 }
325
326 return changes;
327
328 }
329
330
331
332
333
334 private List<TrackingValue> getProducedAndNotConvertedDoubleValues(final AbstractInsnNode instruction) {
335
336 final List<TrackingValue> values = new ArrayList<TrackingValue>();
337
338
339 final Frame<TrackingValue> before = frames.get(instruction);
340 final int beforeStackSize = before.getStackSize();
341 final int locals = before.getLocals();
342
343
344
345 final Set<AbstractInsnNode> set = successors.get(instruction);
346 if (set != null) {
347
348
349 for (final AbstractInsnNode successor : set) {
350 final Frame<TrackingValue> produced = frames.get(successor);
351
352
353 for (int i = 0; i < produced.getStackSize(); ++i) {
354 final TrackingValue value = produced.getStack(i);
355 if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
356 value.getType().equals(Type.DOUBLE_TYPE) &&
357 !converted.contains(value)) {
358 values.add(value);
359 }
360 }
361
362
363 for (int i = 0; i < locals; ++i) {
364 final TrackingValue value = (TrackingValue) produced.getLocal(i);
365 if ((value != before.getLocal(i)) &&
366 value.getType().equals(Type.DOUBLE_TYPE) &&
367 !converted.contains(value)) {
368 values.add(value);
369 }
370 }
371 }
372 }
373
374 return values;
375
376 }
377
378
379
380
381
382
383
384 private InsnList getReplacement(final AbstractInsnNode insn)
385 throws DifferentiationException {
386
387
388 final Frame<TrackingValue> frame = frames.get(insn);
389 final int size = frame.getStackSize();
390 final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2));
391 final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1));
392
393 switch(insn.getOpcode()) {
394 case Opcodes.DLOAD :
395 useLocal(((VarInsnNode) insn).var, 4);
396 return new DLoadTransformer().getReplacement(insn, this);
397 case Opcodes.DALOAD :
398
399 throw new RuntimeException("DALOAD not handled yet");
400 case Opcodes.DSTORE :
401 useLocal(((VarInsnNode) insn).var, 4);
402 return new DStoreTransformer().getReplacement(insn, this);
403 case Opcodes.DASTORE :
404
405 throw new RuntimeException("DASTORE not handled yet");
406 case Opcodes.DUP2 :
407 return new Dup2Transformer().getReplacement(insn, this);
408 case Opcodes.POP2 :
409 return new Pop2Transformer().getReplacement(insn, this);
410 case Opcodes.DUP2_X1 :
411 return new Dup2X1Transformer().getReplacement(insn, this);
412 case Opcodes.DUP2_X2 :
413 return new Dup2X2Transformer(stack0Converted, stack1Converted).getReplacement(insn, this);
414 case Opcodes.DADD :
415 return new DAddTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
416 case Opcodes.DSUB :
417 return new DSubTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
418 case Opcodes.DMUL :
419 return new DMulTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
420 case Opcodes.DDIV :
421 return new DDivTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
422 case Opcodes.DREM :
423 return new DRemTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
424 case Opcodes.DNEG :
425 return new DNegTransformer().getReplacement(insn, this);
426 case Opcodes.DCONST_0 :
427 case Opcodes.DCONST_1 :
428 case Opcodes.LDC :
429 case Opcodes.I2D :
430 case Opcodes.L2D :
431 case Opcodes.F2D :
432 return new WideningTransformer().getReplacement(insn, this);
433 case Opcodes.D2I :
434 case Opcodes.D2L :
435 case Opcodes.D2F :
436 return new NarrowingTransformer().getReplacement(insn, this);
437 case Opcodes.DCMPL :
438 case Opcodes.DCMPG :
439 return new DcmpTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
440 case Opcodes.DRETURN :
441
442 return new DReturnTransformer(true).getReplacement(insn, this);
443 case Opcodes.GETSTATIC :
444
445 throw new RuntimeException("GETSTATIC not handled yet");
446 case Opcodes.PUTSTATIC :
447
448 throw new RuntimeException("PUTSTATIC not handled yet");
449 case Opcodes.GETFIELD :
450 return replaceGetField((FieldInsnNode) insn);
451 case Opcodes.PUTFIELD :
452
453 throw new RuntimeException("PUTFIELD not handled yet");
454 case Opcodes.INVOKEVIRTUAL :
455
456 throw new RuntimeException("INVOKEVIRTUAL not handled yet");
457 case Opcodes.INVOKESPECIAL :
458
459 throw new RuntimeException("INVOKESPECIAL not handled yet");
460 case Opcodes.INVOKESTATIC :
461 return new InvokeStaticTransformer(stack0Converted, stack1Converted).getReplacement(insn, this);
462 case Opcodes.INVOKEINTERFACE :
463
464 throw new RuntimeException("INVOKEINTERFACE not handled yet");
465 case Opcodes.INVOKEDYNAMIC :
466
467 throw new RuntimeException("INVOKEDYNAMIC not handled yet");
468 case Opcodes.NEWARRAY :
469
470 throw new RuntimeException("NEWARRAY not handled yet");
471 case Opcodes.ANEWARRAY :
472
473 throw new RuntimeException("ANEWARRAY not handled yet");
474 case Opcodes.MULTIANEWARRAY :
475
476 throw new RuntimeException("MULTIANEWARRAY not handled yet");
477 default:
478 throw new DifferentiationException(NablaMessages.UNABLE_TO_HANDLE_INSTRUCTION, insn.getOpcode());
479 }
480
481 }
482
483
484
485
486
487
488 private InsnList replaceGetField(final FieldInsnNode fieldInsn) throws DifferentiationException {
489
490 final InsnList list = new InsnList();
491
492
493 list.add(new LdcInsnNode(fieldInsn.name));
494 list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, derivedName, "getPrimitiveField",
495 Type.getMethodDescriptor(Type.getType(Object.class),
496 Type.getType(String.class))));
497
498
499 final Type type = Type.getType(fieldInsn.desc);
500 final Type boxedType;
501 final String valueMethodName;
502 switch (type.getSort()) {
503 case Type.VOID:
504 throw new DifferentiationException(NablaMessages.CANNOT_GET_VOID_FIELD, fieldInsn.name);
505 case Type.BOOLEAN:
506 valueMethodName = "booleanValue";
507 boxedType = Type.getType(Boolean.class);
508 break;
509 case Type.CHAR:
510 valueMethodName = "charValue";
511 boxedType = Type.getType(Character.class);
512 break;
513 case Type.BYTE:
514 valueMethodName = "byteValue";
515 boxedType = Type.getType(Byte.class);
516 break;
517 case Type.SHORT:
518 valueMethodName = "shortValue";
519 boxedType = Type.getType(Short.class);
520 break;
521 case Type.INT:
522 valueMethodName = "intValue";
523 boxedType = Type.getType(Integer.class);
524 break;
525 case Type.FLOAT:
526 valueMethodName = "floatValue";
527 boxedType = Type.getType(Float.class);
528 break;
529 case Type.LONG:
530 valueMethodName = "longValue";
531 boxedType = Type.getType(Long.class);
532 break;
533 case Type.DOUBLE:
534 valueMethodName = "doubleValue";
535 boxedType = Type.getType(Double.class);
536 break;
537 default :
538
539 valueMethodName = null;
540 boxedType = null;
541
542 }
543 if (boxedType != null) {
544 list.add(new TypeInsnNode(Opcodes.CHECKCAST, boxedType.getInternalName()));
545 }
546 if (valueMethodName != null) {
547 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, boxedType.getInternalName(),
548 valueMethodName,
549 Type.getMethodDescriptor(type, new Type[0])));
550 }
551
552 return list;
553
554 }
555
556
557
558
559
560 public boolean isMathImplementationClass(final String name) {
561 return mathClasses.contains(name);
562 }
563
564
565
566
567 public InsnList doubleToDerivativeStructureConversion() {
568
569 final InsnList list = new InsnList();
570
571
572 list.add(new TypeInsnNode(Opcodes.NEW,
573 Type.getInternalName(DerivativeStructure.class)));
574 list.add(new InsnNode(Opcodes.DUP_X2));
575 list.add(new InsnNode(Opcodes.DUP_X2));
576 list.add(new InsnNode(Opcodes.POP));
577 list.add(new VarInsnNode(Opcodes.ALOAD, getInputDSIndex()));
578 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
579 Type.getInternalName(DerivativeStructure.class),
580 "getFreeParameters",
581 Type.getMethodDescriptor(Type.INT_TYPE)));
582 list.add(new VarInsnNode(Opcodes.ALOAD, 1));
583 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
584 Type.getInternalName(DerivativeStructure.class),
585 "getOrder",
586 Type.getMethodDescriptor(Type.INT_TYPE)));
587 list.add(new InsnNode(Opcodes.DUP2_X2));
588 list.add(new InsnNode(Opcodes.POP2));
589 list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL,
590 Type.getInternalName(DerivativeStructure.class),
591 "<init>",
592 Type.getMethodDescriptor(Type.VOID_TYPE,
593 Type.INT_TYPE,
594 Type.INT_TYPE,
595 Type.DOUBLE_TYPE)));
596
597 return list;
598
599 }
600
601
602
603
604
605
606
607
608 public void useLocal(final int index, final int size)
609 throws DifferentiationException {
610 if ((index < 0) || ((index + size) > usedLocals.length)) {
611 throw new DifferentiationException(NablaMessages.INDEX_OF_LOCAL_VARIABLE_OUT_OF_RANGE,
612 size, index, 1, MAX_TEMP);
613 }
614 for (int i = index; i < index + size; ++i) {
615 usedLocals[i] = true;
616 }
617 }
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639 public int getTmp(final int number) throws DifferentiationException {
640 if ((number < 0) || (number > MAX_TEMP)) {
641 throw new DifferentiationException(NablaMessages.NUMBER_OF_TEMPORARY_VARIABLES_OUT_OF_RANGE,
642 number, 1, MAX_TEMP);
643 }
644 final int index = usedLocals.length - 2 * number;
645 useLocal(index, 2);
646 return index;
647 }
648
649
650
651
652 private void shiftVariable(final VarInsnNode insn) {
653 int shifted = 0;
654 for (int i = 0; i < insn.var; ++i) {
655 if (usedLocals[i]) {
656 ++shifted;
657 }
658 }
659 insn.var = shifted;
660 }
661
662
663
664
665 private int maxVariables() {
666 int max = 0;
667 for (final boolean isUsed : usedLocals) {
668 if (isUsed) {
669 ++max;
670 }
671 }
672 return max;
673 }
674
675
676 private class FlowAnalyzer extends Analyzer<TrackingValue> {
677
678
679 private final InsnList instructions;
680
681
682
683
684
685 public FlowAnalyzer(final Interpreter<TrackingValue> interpreter,
686 final InsnList instructions) {
687 super(interpreter);
688 this.instructions = instructions;
689 }
690
691
692
693
694
695 protected void newControlFlowEdge(final int insn, final int successor) {
696
697 final AbstractInsnNode node = instructions.get(insn);
698 Set<AbstractInsnNode> set = successors.get(node);
699 if (set == null) {
700 set = new HashSet<AbstractInsnNode>();
701 successors.put(node, set);
702 }
703 set.add(instructions.get(successor));
704 }
705
706 }
707
708 }