/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.apache.commons.nabla.forward.arithmetic;
18  
19  import org.apache.commons.nabla.DifferentiationException;
20  import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
21  import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
22  import org.objectweb.asm.Opcodes;
23  import org.objectweb.asm.Type;
24  import org.objectweb.asm.tree.AbstractInsnNode;
25  import org.objectweb.asm.tree.InsnList;
26  import org.objectweb.asm.tree.InsnNode;
27  import org.objectweb.asm.tree.MethodInsnNode;
28  import org.objectweb.asm.tree.VarInsnNode;
29  
30  /** Differentiation transformer for DREM instructions.
31   * @version $Id$
32   */
33  public class DRemTransformer implements InstructionsTransformer {
34  
35      /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the remainder. */
36      private static final String REMAINDER_METHOD = "remainder";
37  
38      /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the value getter. */
39      private static final String VALUE_GETTER_METHOD = "getValue";
40  
41      /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the multiplication. */
42      private static final String MULTIPLY_METHOD = "multiply";
43  
44      /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the addition. */
45      private static final String ADD_METHOD = "add";
46  
47      /** Indicator for top stack element conversion. */
48      private final boolean stack0Converted;
49  
50      /** Indicator for next to top stack element conversion. */
51      private final boolean stack1Converted;
52  
53      /** Simple constructor.
54       * @param stack0Converted if true, the top level stack element has already been converted
55       * @param stack1Converted if true, the next to top level stack element has already been converted
56       */
57      public DRemTransformer(final boolean stack0Converted, final boolean stack1Converted) {
58          this.stack0Converted = stack0Converted;
59          this.stack1Converted = stack1Converted;
60      }
61  
62      /** {@inheritDoc} */
63      public InsnList getReplacement(final AbstractInsnNode insn,
64                                     final MethodDifferentiator methodDifferentiator)
65          throws DifferentiationException {
66  
67          final InsnList list = new InsnList();
68  
69          if (stack1Converted) {
70              if (stack0Converted) {
71                  list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
72                                              REMAINDER_METHOD,
73                                              Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
74              } else {
75                  list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
76                                              REMAINDER_METHOD,
77                                              Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
78              }
79          } else {
80  
81              // set up a temporary variable
82              final int tmp1 = methodDifferentiator.getTmp(1);
83  
84              // operand stack initial state: a, ds_b
85              list.add(new InsnNode(Opcodes.DUP_X2));                                            // => ds_b, a, ds_b
86              list.add(new InsnNode(Opcodes.POP));                                               // => ds_b, a
87              list.add(new VarInsnNode(Opcodes.DSTORE, tmp1));                                   // => ds_b
88              list.add(new InsnNode(Opcodes.DUP));                                               // => ds_b, ds_b
89              list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
90                                          VALUE_GETTER_METHOD,
91                                          Type.getMethodDescriptor(Type.DOUBLE_TYPE)));          // => ds_b, b0
92              list.add(new InsnNode(Opcodes.DUP2));                                              // => ds_b, b0, b0
93              list.add(new VarInsnNode(Opcodes.DLOAD,  tmp1));                                   // => ds_b, b0, b0, a
94              list.add(new InsnNode(Opcodes.DUP2_X2));                                           // => ds_b, b0, a, b0, a
95              list.add(new InsnNode(Opcodes.POP2));                                              // => ds_b, b0, a, b0
96              list.add(new InsnNode(Opcodes.DREM));                                              // => ds_b, b0, a%b0
97              list.add(new VarInsnNode(Opcodes.DLOAD,  tmp1));                                   // => ds_b, b0, a%b0, a
98              list.add(new InsnNode(Opcodes.DSUB));                                              // => ds_b, b0, a%b0-a
99              list.add(new InsnNode(Opcodes.DUP2_X2));                                           // => ds_b, a%b0-a, b0, a%b0-a
100             list.add(new InsnNode(Opcodes.POP2));                                              // => ds_b, a%b0-a, b0
101             list.add(new InsnNode(Opcodes.DDIV));                                              // => ds_b, q=(a%b0-a)/b0
102             list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
103                                         MULTIPLY_METHOD,
104                                         Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE))); // => q*b
105             list.add(new VarInsnNode(Opcodes.DLOAD,  tmp1));                                   // => q*b, a
106             list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
107                                         ADD_METHOD,
108                                         Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE))); // => a+q*b
109 
110         }
111 
112         return list;
113 
114     }
115 
116 }