001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.forward.arithmetic;
018    
019    import org.apache.commons.nabla.DifferentiationException;
020    import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
021    import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
022    import org.objectweb.asm.Opcodes;
023    import org.objectweb.asm.Type;
024    import org.objectweb.asm.tree.AbstractInsnNode;
025    import org.objectweb.asm.tree.InsnList;
026    import org.objectweb.asm.tree.InsnNode;
027    import org.objectweb.asm.tree.MethodInsnNode;
028    import org.objectweb.asm.tree.VarInsnNode;
029    
030    /** Differentiation transformer for DREM instructions.
031     * @version $Id$
032     */
033    public class DRemTransformer implements InstructionsTransformer {
034    
035        /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the remainder. */
036        private static final String REMAINDER_METHOD = "remainder";
037    
038        /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the value getter. */
039        private static final String VALUE_GETTER_METHOD = "getValue";
040    
041        /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the multiplication. */
042        private static final String MULTIPLY_METHOD = "multiply";
043    
044        /** Name of the {@link org.apache.commons.math3.analysis.differentiation.DerivativeStructure} method corresponding to the addition. */
045        private static final String ADD_METHOD = "add";
046    
047        /** Indicator for top stack element conversion. */
048        private final boolean stack0Converted;
049    
050        /** Indicator for next to top stack element conversion. */
051        private final boolean stack1Converted;
052    
053        /** Simple constructor.
054         * @param stack0Converted if true, the top level stack element has already been converted
055         * @param stack1Converted if true, the next to top level stack element has already been converted
056         */
057        public DRemTransformer(final boolean stack0Converted, final boolean stack1Converted) {
058            this.stack0Converted = stack0Converted;
059            this.stack1Converted = stack1Converted;
060        }
061    
062        /** {@inheritDoc} */
063        public InsnList getReplacement(final AbstractInsnNode insn,
064                                       final MethodDifferentiator methodDifferentiator)
065            throws DifferentiationException {
066    
067            final InsnList list = new InsnList();
068    
069            if (stack1Converted) {
070                if (stack0Converted) {
071                    list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
072                                                REMAINDER_METHOD,
073                                                Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
074                } else {
075                    list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
076                                                REMAINDER_METHOD,
077                                                Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
078                }
079            } else {
080    
081                // set up a temporary variable
082                final int tmp1 = methodDifferentiator.getTmp(1);
083    
084                // operand stack initial state: a, ds_b
085                list.add(new InsnNode(Opcodes.DUP_X2));                                            // => ds_b, a, ds_b
086                list.add(new InsnNode(Opcodes.POP));                                               // => ds_b, a
087                list.add(new VarInsnNode(Opcodes.DSTORE, tmp1));                                   // => ds_b
088                list.add(new InsnNode(Opcodes.DUP));                                               // => ds_b, ds_b
089                list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
090                                            VALUE_GETTER_METHOD,
091                                            Type.getMethodDescriptor(Type.DOUBLE_TYPE)));          // => ds_b, b0
092                list.add(new InsnNode(Opcodes.DUP2));                                              // => ds_b, b0, b0
093                list.add(new VarInsnNode(Opcodes.DLOAD,  tmp1));                                   // => ds_b, b0, b0, a
094                list.add(new InsnNode(Opcodes.DUP2_X2));                                           // => ds_b, b0, a, b0, a
095                list.add(new InsnNode(Opcodes.POP2));                                              // => ds_b, b0, a, b0
096                list.add(new InsnNode(Opcodes.DREM));                                              // => ds_b, b0, a%b0
097                list.add(new VarInsnNode(Opcodes.DLOAD,  tmp1));                                   // => ds_b, b0, a%b0, a
098                list.add(new InsnNode(Opcodes.DSUB));                                              // => ds_b, b0, a%b0-a
099                list.add(new InsnNode(Opcodes.DUP2_X2));                                           // => ds_b, a%b0-a, b0, a%b0-a
100                list.add(new InsnNode(Opcodes.POP2));                                              // => ds_b, a%b0-a, b0
101                list.add(new InsnNode(Opcodes.DDIV));                                              // => ds_b, q=(a%b0-a)/b0
102                list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
103                                            MULTIPLY_METHOD,
104                                            Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE))); // => q*b
105                list.add(new VarInsnNode(Opcodes.DLOAD,  tmp1));                                   // => q*b, a
106                list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DS_TYPE.getInternalName(),
107                                            ADD_METHOD,
108                                            Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE))); // => a+q*b
109    
110            }
111    
112            return list;
113    
114        }
115    
116    }