1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.nabla.forward.instructions;
18
19 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
20 import org.apache.commons.nabla.DifferentiationException;
21 import org.apache.commons.nabla.NablaMessages;
22 import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
23 import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
24 import org.objectweb.asm.Opcodes;
25 import org.objectweb.asm.Type;
26 import org.objectweb.asm.tree.AbstractInsnNode;
27 import org.objectweb.asm.tree.InsnList;
28 import org.objectweb.asm.tree.InsnNode;
29 import org.objectweb.asm.tree.MethodInsnNode;
30
31 /** Differentiation transformer for INVOKESTATIC instructions.
32 * @version $Id$
33 */
34 public class InvokeStaticTransformer implements InstructionsTransformer {
35
36 /** Indicator for top stack element conversion. */
37 private final boolean stack0Converted;
38
39 /** Indicator for next to top stack element conversion. */
40 private final boolean stack1Converted;
41
42 /** Simple constructor.
43 * @param stack0Converted if true, the top level stack element has already been converted
44 * @param stack1Converted if true, the next to top level stack element has already been converted
45 */
46 public InvokeStaticTransformer(final boolean stack0Converted, final boolean stack1Converted) {
47 this.stack0Converted = stack0Converted;
48 this.stack1Converted = stack1Converted;
49 }
50
51 /** {@inheritDoc} */
52 public InsnList getReplacement(final AbstractInsnNode insn,
53 final MethodDifferentiator methodDifferentiator)
54 throws DifferentiationException {
55
56 final MethodInsnNode methodInsn = (MethodInsnNode) insn;
57 if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) {
58 // TODO: handle INVOKESTATIC on non math related classes
59 throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" +
60 methodInsn.owner + methodInsn.owner);
61 }
62
63 final InsnList list = new InsnList();
64
65 if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
66 // this is a univariate method like sin, cos, exp ...
67
68 try {
69 // check that a corresponding method exist for DerivativeStructure
70 DerivativeStructure.class.getDeclaredMethod(methodInsn.name);
71 } catch (NoSuchMethodException nsme) {
72 throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
73 methodInsn.owner, methodInsn.name);
74 }
75
76 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
77 DS_TYPE.getInternalName(), methodInsn.name,
78 Type.getMethodDescriptor(DS_TYPE)));
79
80 } else if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
81 // this is a bivariate method like atan2, pow ...
82
83 if (methodInsn.name.equals("pow")) {
84 // special case for pow: in DerivativeStructure, it is an instance method,
85 // not a static method as the other two parameters functions like atan2 or hypot
86
87 if (stack1Converted) {
88 if (!stack0Converted) {
89 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
90 DS_TYPE.getInternalName(), methodInsn.name,
91 Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
92 } else {
93 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
94 DS_TYPE.getInternalName(), methodInsn.name,
95 Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
96 }
97 } else {
98
99 // initial stack state: x, ds_y
100 list.add(new InsnNode(Opcodes.DUP_X2)); // => ds_y, x, ds_y
101 list.add(new InsnNode(Opcodes.POP)); // => ds_y, x
102 list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
103 list.add(new InsnNode(Opcodes.SWAP)); // => ds_x, ds_y
104
105 // call the static two parameters method for DerivativeStructure
106 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
107 DS_TYPE.getInternalName(), methodInsn.name,
108 Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
109 }
110
111 } else {
112
113 if (stack1Converted) {
114 if (!stack0Converted) {
115 // the top level element is not a DerivativeStructure, convert it
116 list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
117 }
118 } else {
119 // initial stack state: x, ds_y
120 list.add(new InsnNode(Opcodes.DUP_X2)); // => ds_y, x, ds_y
121 list.add(new InsnNode(Opcodes.POP)); // => ds_y, x
122 list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
123 list.add(new InsnNode(Opcodes.SWAP)); // => ds_x, ds_y
124 }
125
126 // call the static two parameters method for DerivativeStructure
127 list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
128 DS_TYPE.getInternalName(), methodInsn.name,
129 Type.getMethodDescriptor(DS_TYPE, DS_TYPE, DS_TYPE)));
130
131 }
132
133 } else {
134 throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
135 methodInsn.owner, methodInsn.name);
136 }
137
138 return list;
139
140 }
141
142 }