1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.nabla.forward.instructions;
18
19 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
20 import org.apache.commons.nabla.DifferentiationException;
21 import org.apache.commons.nabla.NablaMessages;
22 import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
23 import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
24 import org.objectweb.asm.Opcodes;
25 import org.objectweb.asm.Type;
26 import org.objectweb.asm.tree.AbstractInsnNode;
27 import org.objectweb.asm.tree.InsnList;
28 import org.objectweb.asm.tree.InsnNode;
29 import org.objectweb.asm.tree.MethodInsnNode;
30
31
32
33
34 public class InvokeStaticTransformer implements InstructionsTransformer {
35
36
37 private final boolean stack0Converted;
38
39
40 private final boolean stack1Converted;
41
42
43
44
45
46 public InvokeStaticTransformer(final boolean stack0Converted, final boolean stack1Converted) {
47 this.stack0Converted = stack0Converted;
48 this.stack1Converted = stack1Converted;
49 }
50
51
52 public InsnList getReplacement(final AbstractInsnNode insn,
53 final MethodDifferentiator methodDifferentiator)
54 throws DifferentiationException {
55
56 final MethodInsnNode methodInsn = (MethodInsnNode) insn;
57 if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) {
58
59 throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" +
60 methodInsn.owner + methodInsn.owner);
61 }
62
63 final InsnList list = new InsnList();
64
65 if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
66
67
68 try {
69
70 DerivativeStructure.class.getDeclaredMethod(methodInsn.name);
71 } catch (NoSuchMethodException nsme) {
72 throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
73 methodInsn.owner, methodInsn.name);
74 }
75
76 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
77 DS_TYPE.getInternalName(), methodInsn.name,
78 Type.getMethodDescriptor(DS_TYPE)));
79
80 } else if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
81
82
83 if (methodInsn.name.equals("pow")) {
84
85
86
87 if (stack1Converted) {
88 if (!stack0Converted) {
89 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
90 DS_TYPE.getInternalName(), methodInsn.name,
91 Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
92 } else {
93 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
94 DS_TYPE.getInternalName(), methodInsn.name,
95 Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
96 }
97 } else {
98
99
100 list.add(new InsnNode(Opcodes.DUP_X2));
101 list.add(new InsnNode(Opcodes.POP));
102 list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
103 list.add(new InsnNode(Opcodes.SWAP));
104
105
106 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
107 DS_TYPE.getInternalName(), methodInsn.name,
108 Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
109 }
110
111 } else {
112
113 if (stack1Converted) {
114 if (!stack0Converted) {
115
116 list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
117 }
118 } else {
119
120 list.add(new InsnNode(Opcodes.DUP_X2));
121 list.add(new InsnNode(Opcodes.POP));
122 list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
123 list.add(new InsnNode(Opcodes.SWAP));
124 }
125
126
127 list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
128 DS_TYPE.getInternalName(), methodInsn.name,
129 Type.getMethodDescriptor(DS_TYPE, DS_TYPE, DS_TYPE)));
130
131 }
132
133 } else {
134 throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
135 methodInsn.owner, methodInsn.name);
136 }
137
138 return list;
139
140 }
141
142 }