View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla.forward.instructions;
18  
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.nabla.DifferentiationException;
import org.apache.commons.nabla.NablaMessages;
import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
30  
31  /** Differentiation transformer for INVOKESTATIC instructions.
32   * @version $Id$
33   */
34  public class InvokeStaticTransformer implements InstructionsTransformer {
35  
36      /** Indicator for top stack element conversion. */
37      private final boolean stack0Converted;
38  
39      /** Indicator for next to top stack element conversion. */
40      private final boolean stack1Converted;
41  
42      /** Simple constructor.
43       * @param stack0Converted if true, the top level stack element has already been converted
44       * @param stack1Converted if true, the next to top level stack element has already been converted
45       */
46      public InvokeStaticTransformer(final boolean stack0Converted, final boolean stack1Converted) {
47          this.stack0Converted = stack0Converted;
48          this.stack1Converted = stack1Converted;
49      }
50  
51      /** {@inheritDoc} */
52      public InsnList getReplacement(final AbstractInsnNode insn,
53                                     final MethodDifferentiator methodDifferentiator)
54          throws DifferentiationException {
55  
56          final MethodInsnNode methodInsn = (MethodInsnNode) insn;
57          if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) {
58              // TODO: handle INVOKESTATIC on non math related classes
59              throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" +
60                      methodInsn.owner + methodInsn.owner);
61          }
62  
63          final InsnList list = new InsnList();
64  
65          if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
66              // this is a univariate method like sin, cos, exp ...
67  
68              try {
69                  // check that a corresponding method exist for DerivativeStructure
70                  DerivativeStructure.class.getDeclaredMethod(methodInsn.name);
71              } catch (NoSuchMethodException nsme) {
72                  throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
73                                                     methodInsn.owner, methodInsn.name);
74              }
75  
76              list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
77                                          DS_TYPE.getInternalName(), methodInsn.name,
78                                          Type.getMethodDescriptor(DS_TYPE)));
79  
80          } else if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
81              // this is a bivariate method like atan2, pow ...
82  
83              if (methodInsn.name.equals("pow")) {
84                  // special case for pow: in DerivativeStructure, it is an instance method,
85                  // not a static method as the other two parameters functions like atan2 or hypot
86  
87                  if (stack1Converted) {
88                      if (!stack0Converted) {
89                          list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
90                                                      DS_TYPE.getInternalName(), methodInsn.name,
91                                                      Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
92                      } else {
93                          list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
94                                                      DS_TYPE.getInternalName(), methodInsn.name,
95                                                      Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
96                      }
97                  } else {
98  
99                      // initial stack state: x, ds_y
100                     list.add(new InsnNode(Opcodes.DUP_X2));                                 // => ds_y, x, ds_y
101                     list.add(new InsnNode(Opcodes.POP));                                    // => ds_y, x
102                     list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
103                     list.add(new InsnNode(Opcodes.SWAP));                                   // => ds_x, ds_y
104 
105                     // call the static two parameters method for DerivativeStructure
106                     list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
107                                                 DS_TYPE.getInternalName(), methodInsn.name,
108                                                 Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
109                 }
110 
111             } else {
112 
113                 if (stack1Converted) {
114                     if (!stack0Converted) {
115                         // the top level element is not a DerivativeStructure, convert it
116                         list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
117                     }
118                 } else {
119                     // initial stack state: x, ds_y
120                     list.add(new InsnNode(Opcodes.DUP_X2));                                 // => ds_y, x, ds_y
121                     list.add(new InsnNode(Opcodes.POP));                                    // => ds_y, x
122                     list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
123                     list.add(new InsnNode(Opcodes.SWAP));                                   // => ds_x, ds_y
124                 }
125 
126                 // call the static two parameters method for DerivativeStructure
127                 list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
128                                             DS_TYPE.getInternalName(), methodInsn.name,
129                                             Type.getMethodDescriptor(DS_TYPE, DS_TYPE, DS_TYPE)));
130 
131             }
132 
133         } else {
134             throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
135                                                methodInsn.owner, methodInsn.name);
136         }
137 
138         return list;
139 
140     }
141 
142 }