001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.forward.instructions;
018    
019    import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
020    import org.apache.commons.nabla.DifferentiationException;
021    import org.apache.commons.nabla.NablaMessages;
022    import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
023    import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
024    import org.objectweb.asm.Opcodes;
025    import org.objectweb.asm.Type;
026    import org.objectweb.asm.tree.AbstractInsnNode;
027    import org.objectweb.asm.tree.InsnList;
028    import org.objectweb.asm.tree.InsnNode;
029    import org.objectweb.asm.tree.MethodInsnNode;
030    
031    /** Differentiation transformer for INVOKESTATIC instructions.
032     * @version $Id$
033     */
034    public class InvokeStaticTransformer implements InstructionsTransformer {
035    
036        /** Indicator for top stack element conversion. */
037        private final boolean stack0Converted;
038    
039        /** Indicator for next to top stack element conversion. */
040        private final boolean stack1Converted;
041    
042        /** Simple constructor.
043         * @param stack0Converted if true, the top level stack element has already been converted
044         * @param stack1Converted if true, the next to top level stack element has already been converted
045         */
046        public InvokeStaticTransformer(final boolean stack0Converted, final boolean stack1Converted) {
047            this.stack0Converted = stack0Converted;
048            this.stack1Converted = stack1Converted;
049        }
050    
051        /** {@inheritDoc} */
052        public InsnList getReplacement(final AbstractInsnNode insn,
053                                       final MethodDifferentiator methodDifferentiator)
054            throws DifferentiationException {
055    
056            final MethodInsnNode methodInsn = (MethodInsnNode) insn;
057            if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) {
058                // TODO: handle INVOKESTATIC on non math related classes
059                throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" +
060                        methodInsn.owner + methodInsn.owner);
061            }
062    
063            final InsnList list = new InsnList();
064    
065            if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
066                // this is a univariate method like sin, cos, exp ...
067    
068                try {
069                    // check that a corresponding method exist for DerivativeStructure
070                    DerivativeStructure.class.getDeclaredMethod(methodInsn.name);
071                } catch (NoSuchMethodException nsme) {
072                    throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
073                                                       methodInsn.owner, methodInsn.name);
074                }
075    
076                list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
077                                            DS_TYPE.getInternalName(), methodInsn.name,
078                                            Type.getMethodDescriptor(DS_TYPE)));
079    
080            } else if (methodInsn.desc.equals(Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE))) {
081                // this is a bivariate method like atan2, pow ...
082    
083                if (methodInsn.name.equals("pow")) {
084                    // special case for pow: in DerivativeStructure, it is an instance method,
085                    // not a static method as the other two parameters functions like atan2 or hypot
086    
087                    if (stack1Converted) {
088                        if (!stack0Converted) {
089                            list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
090                                                        DS_TYPE.getInternalName(), methodInsn.name,
091                                                        Type.getMethodDescriptor(DS_TYPE, Type.DOUBLE_TYPE)));
092                        } else {
093                            list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
094                                                        DS_TYPE.getInternalName(), methodInsn.name,
095                                                        Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
096                        }
097                    } else {
098    
099                        // initial stack state: x, ds_y
100                        list.add(new InsnNode(Opcodes.DUP_X2));                                 // => ds_y, x, ds_y
101                        list.add(new InsnNode(Opcodes.POP));                                    // => ds_y, x
102                        list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
103                        list.add(new InsnNode(Opcodes.SWAP));                                   // => ds_x, ds_y
104    
105                        // call the static two parameters method for DerivativeStructure
106                        list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
107                                                    DS_TYPE.getInternalName(), methodInsn.name,
108                                                    Type.getMethodDescriptor(DS_TYPE, DS_TYPE)));
109                    }
110    
111                } else {
112    
113                    if (stack1Converted) {
114                        if (!stack0Converted) {
115                            // the top level element is not a DerivativeStructure, convert it
116                            list.add(methodDifferentiator.doubleToDerivativeStructureConversion());
117                        }
118                    } else {
119                        // initial stack state: x, ds_y
120                        list.add(new InsnNode(Opcodes.DUP_X2));                                 // => ds_y, x, ds_y
121                        list.add(new InsnNode(Opcodes.POP));                                    // => ds_y, x
122                        list.add(methodDifferentiator.doubleToDerivativeStructureConversion()); // => ds_y, ds_x
123                        list.add(new InsnNode(Opcodes.SWAP));                                   // => ds_x, ds_y
124                    }
125    
126                    // call the static two parameters method for DerivativeStructure
127                    list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
128                                                DS_TYPE.getInternalName(), methodInsn.name,
129                                                Type.getMethodDescriptor(DS_TYPE, DS_TYPE, DS_TYPE)));
130    
131                }
132    
133            } else {
134                throw new DifferentiationException(NablaMessages.UNKNOWN_METHOD,
135                                                   methodInsn.owner, methodInsn.name);
136            }
137    
138            return list;
139    
140        }
141    
142    }