001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.forward;
018    
019    import java.io.IOException;
020    import java.io.OutputStream;
021    import java.lang.reflect.Constructor;
022    import java.lang.reflect.InvocationTargetException;
023    import java.util.HashMap;
024    import java.util.HashSet;
025    import java.util.Set;
026    
027    import org.apache.commons.math3.analysis.UnivariateFunction;
028    import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
029    import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
030    import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator;
031    import org.apache.commons.math3.util.FastMath;
032    import org.apache.commons.nabla.DifferentiationException;
033    import org.apache.commons.nabla.NablaMessages;
034    import org.apache.commons.nabla.forward.analysis.ClassDifferentiator;
035    import org.objectweb.asm.ClassWriter;
036    import org.objectweb.asm.Type;
037    import org.objectweb.asm.tree.ClassNode;
038    
039    /** Algorithmic differentiator class in forward mode based on bytecode analysis.
040     * <p>This class is an implementation of the {@link UnivariateFunctionDifferentiator}
041     * interface that computes <em>exact</em> differentials completely automatically
042     * and generate java classes and instances that compute the differential
043     * of the function as if they were hand-coded and compiled.</p>
044     * <p>The derivative bytecode created the first time an instance of a given class
045     * is differentiated is cached and will be reused if other instances of the same class
046     * are to be created later. The cache can also be dumped in a jar file for
047     * use in an application without bringing the full nabla library and its
048     * dependencies.</p>
049     * <p>This differentiator can handle only pure bytecode methods and known methods
050     * from math implementation classes like {@link java.lang.Math Math}, {@link
051     * java.lang.StrictMath StrictMath} or {@link FastMath}. Pure bytecode methods are
052     * analyzed and converted. Methods from math implementation classes are only
053     * recognized by class and name and replaced by predefined derivative code.</p>
054     * @see org.apache.commons.nabla.caching.FetchDifferentiator
055     * @version $Id$
056     */
057    public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
058    
059        /** UnivariateFunction/UnivariateDifferentiableFunction map. */
060        private final HashMap<Class<? extends UnivariateFunction>,
061                              Class<? extends UnivariateDifferentiableFunction>> map;
062    
063        /** Class name/ bytecode map. */
064        private final HashMap<String, byte[]> byteCodeMap;
065    
066        /** Math implementation classes. */
067        private final Set<String> mathClasses;
068    
069        /** Simple constructor.
070         * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
071         */
072        public ForwardModeDifferentiator() {
073            map         = new HashMap<Class<? extends UnivariateFunction>,
074                                      Class<? extends UnivariateDifferentiableFunction>>();
075            byteCodeMap = new HashMap<String, byte[]>();
076            mathClasses = new HashSet<String>();
077            addMathImplementation(Math.class);
078            addMathImplementation(StrictMath.class);
079            addMathImplementation(FastMath.class);
080        }
081    
082        /** Add an implementation class for mathematical functions.
083         * <p>At construction, the differentiator considers only the {@link
084         * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
085         * classes are math implementation classes. It may be useful to add
086         * other classes for example to add some missing functions like
087         * inverse hyperbolic cosine that are not provided by the standard
088         * java classes as of Java 1.6.</p>
089         * @param mathClass implementation class for mathematical functions
090         */
091        public void addMathImplementation(final Class<?> mathClass) {
092            mathClasses.add(mathClass.getName().replace('.', '/'));
093        }
094    
095        /** Dump the cache into a stream.
096         * @param out output stream where to dump the cache
097         */
098        public void dumpCache(final OutputStream out) {
099            // TODO: implement cache persistence
100            throw new RuntimeException("not implemented yet");
101        }
102    
103        /** {@inheritDoc} */
104        public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
105    
106            // get the derivative class
107            final Class<? extends UnivariateDifferentiableFunction> derivativeClass =
108                getDerivativeClass(d.getClass());
109    
110            try {
111    
112                // create the instance
113                final Constructor<? extends UnivariateDifferentiableFunction> constructor =
114                    derivativeClass.getConstructor(d.getClass());
115                return constructor.newInstance(d);
116    
117            } catch (InstantiationException ie) {
118                throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS,
119                                                   derivativeClass.getName(), ie.getMessage());
120            } catch (IllegalAccessException iae) {
121                throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR,
122                                                   derivativeClass.getName(), iae.getMessage());
123            } catch (NoSuchMethodException nsme) {
124                throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS,
125                                                   derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
126            } catch (InvocationTargetException ite) {
127                throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE,
128                                                   derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
129            } catch (VerifyError ve) {
130                throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE,
131                                                   derivativeClass.getName(), d.getClass().getName(), ve.getMessage());
132            }
133    
134        }
135    
136        /** Get the derivative class of a differentiable class.
137         * <p>The derivative class is either built on the fly
138         * or retrieved from the cache if it has been built previously.</p>
139         * @param differentiableClass class to differentiate
140         * @return derivative class
141         * @throws DifferentiationException if the class cannot be differentiated
142         */
143        private Class<? extends UnivariateDifferentiableFunction>
144        getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
145            throws DifferentiationException {
146    
147            // lookup in the map if the class has already been differentiated
148            Class<? extends UnivariateDifferentiableFunction> derivativeClass =
149                map.get(differentiableClass);
150    
151            // build the derivative class if it does not exist yet
152            if (derivativeClass == null) {
153    
154                // perform algorithmic differentiation
155                derivativeClass = createDerivativeClass(differentiableClass);
156    
157                // put the newly created class in the map
158                map.put(differentiableClass, derivativeClass);
159    
160            }
161    
162            // return the derivative class
163            return derivativeClass;
164    
165        }
166    
167        /** Build a derivative class of a differentiable class.
168         * @param differentiableClass class to differentiate
169         * @return derivative class
170         * @throws DifferentiationException if the class cannot be differentiated
171         */
172        private Class<? extends UnivariateDifferentiableFunction>
173        createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
174            throws DifferentiationException {
175            try {
176    
177                // differentiate the function embedded in the differentiable class
178                final ClassDifferentiator differentiator =
179                    new ClassDifferentiator(differentiableClass, mathClasses);
180                final Type dsType = Type.getType(DerivativeStructure.class);
181                differentiator.differentiateMethod("value",
182                                                   Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
183                                                   Type.getMethodDescriptor(dsType, dsType));
184    
185                // create the derivative class
186                final ClassNode   derived = differentiator.getDerivedClass();
187                final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
188                final String name = derived.name.replace('/', '.');
189                derived.accept(writer);
190                final byte[] bytecode = writer.toByteArray();
191    
192                final Class<? extends UnivariateDifferentiableFunction> dClass =
193                        new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
194                byteCodeMap.put(name, bytecode);
195                return dClass;
196    
197            } catch (IOException ioe) {
198                throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
199                                                   differentiableClass.getName(), ioe.getMessage());
200            }
201        }
202    
203        /** Class loader generating derivative classes. */
204        private static class DerivativeLoader extends ClassLoader {
205    
206            /** Simple constructor.
207             * @param differentiableClass differentiable class
208             */
209            public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) {
210                super(differentiableClass.getClassLoader());
211            }
212    
213            /** Define a derivative class.
214             * @param name name of the differentiated class
215             * @param bytecode bytecode of the differentiated class
216             * @return a generated derivative class
217             */
218            @SuppressWarnings("unchecked")
219            public Class<? extends UnivariateDifferentiableFunction>
220            defineClass(final String name, final byte[] bytecode) {
221                return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length);
222            }
223        }
224    
225    }