001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.nabla.algorithmic.forward;
018    
019    import java.io.IOException;
020    import java.io.InputStream;
021    import java.io.OutputStream;
022    import java.lang.reflect.Constructor;
023    import java.lang.reflect.InvocationTargetException;
024    import java.util.HashMap;
025    import java.util.HashSet;
026    import java.util.Set;
027    
028    import org.apache.commons.nabla.algorithmic.forward.analysis.ClassDifferentiator;
029    import org.apache.commons.nabla.core.DifferentiationException;
030    import org.apache.commons.nabla.core.UnivariateDerivative;
031    import org.apache.commons.nabla.core.UnivariateDifferentiable;
032    import org.apache.commons.nabla.core.UnivariateDifferentiator;
033    import org.objectweb.asm.ClassReader;
034    import org.objectweb.asm.ClassWriter;
035    
036    /** Automatic differentiator class based on bytecode analysis.
037     * <p>This class is an implementation of the {@link UnivariateDifferentiator}
038     * interface that computes <em>exact</em> differentials completely automatically
039     * and generate java classes and instances that compute the differential
040     * of the function as if they were hand-coded and compiled.</p>
041     * <p>The derivative bytecode created the first time an instance of a given class
042     * is differentiated is cached and will be reused if other instances of the same class
043     * are to be created later. The cache can also be dumped in a jar file for
044     * use in an application without bringing the full nabla library and its
045     * dependencies.</p>
046     * <p>This differentiator can handle only pure bytecode methods and known methods
047     * from math implementation classes like {@link java.lang.Math Math} or
048     * {@link java.lang.StrictMath StrictMath}. Pure bytecode methods are analyzed
049     * and converted. Methods from math implementation classes are only recognized
050     * by class and name and replaced by predefined derivative code.</p>
051     * @see org.apache.commons.nabla.caching.FetchDifferentiator
052     */
053    public class ForwardAlgorithmicDifferentiator implements UnivariateDifferentiator {
054    
055        /** UnivariateDifferentiable/UnivariateDerivative map. */
056        private final HashMap<Class<? extends UnivariateDifferentiable>,
057        Class<? extends UnivariateDerivative>> map;
058    
059        /** Math implementation classes. */
060        private final Set<String> mathClasses;
061    
062        /** Simple constructor.
063         * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
064         */
065        public ForwardAlgorithmicDifferentiator() {
066            map = new HashMap<Class<? extends UnivariateDifferentiable>,
067            Class<? extends UnivariateDerivative>>();
068            mathClasses = new HashSet<String>();
069            addMathImplementation(Math.class);
070            addMathImplementation(StrictMath.class);
071        }
072    
073        /** Add an implementation class for mathematical functions.
074         * <p>At construction, the differentiator considers only the {@link
075         * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
076         * classes are math implementation classes. It may be useful to add
077         * other class for example to add some missing functions like
078         * inverse hyperbolic cosine that are not provided by the standard
079         * java classes as of Java 1.6.</p>
080         * @param mathClass implementation class for mathematical functions
081         */
082        public void addMathImplementation(final Class<?> mathClass) {
083            mathClasses.add(mathClass.getName().replace('.', '/'));
084        }
085    
086        /** Dump the cache into a stream.
087         * @param out output stream where to dump the cache
088         */
089        public void dumpCache(final OutputStream out) {
090            // TODO implement cache persistence
091            throw new RuntimeException("not implemented yet");
092        }
093    
094        /** {@inheritDoc} */
095        public UnivariateDerivative differentiate(final UnivariateDifferentiable d)
096            throws DifferentiationException {
097    
098            // get the derivative class
099            final Class<? extends UnivariateDerivative> derivativeClass =
100                getDerivativeClass(d.getClass());
101    
102            try {
103    
104                // create the instance
105                final Constructor<? extends UnivariateDerivative> constructor =
106                    derivativeClass.getConstructor(d.getClass());
107                return constructor.newInstance(d);
108    
109            } catch (InstantiationException ie) {
110                throw new DifferentiationException("abstract class {0} cannot be instantiated ({1})",
111                                                   derivativeClass.getName(), ie.getMessage());
112            } catch (IllegalAccessException iae) {
113                throw new DifferentiationException("illegal access to class {0} constructor ({1})",
114                                                   derivativeClass.getName(), iae.getMessage());
115            } catch (NoSuchMethodException nsme) {
116                throw new DifferentiationException("class {0} cannot be built from an instance of class {1} ({2})",
117                                                   derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
118            } catch (InvocationTargetException ite) {
119                throw new DifferentiationException("class {0} instantiation from an instance of class {1} failed ({2})",
120                                                   derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
121            }
122    
123        }
124    
125        /** Get the derivative class of a differentiable class.
126         * <p>The derivative class is either built on the fly
127         * or retrieved from the cache if it has been built previously.</p>
128         * @param differentiableClass class to differentiate
129         * @return derivative class
130         * @throws DifferentiationException if the class cannot be differentiated
131         */
132        private Class<? extends UnivariateDerivative>
133        getDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
134            throws DifferentiationException {
135    
136            // lookup in the map if the class has already been differentiated
137            Class<? extends UnivariateDerivative> derivativeClass =
138                map.get(differentiableClass);
139    
140            // build the derivative class if it does not exist yet
141            if (derivativeClass == null) {
142                // perform analytical differentiation
143                derivativeClass = createDerivativeClass(differentiableClass);
144    
145                // put the newly created class in the map
146                map.put(differentiableClass, derivativeClass);
147    
148            }
149    
150            // return the derivative class
151            return derivativeClass;
152    
153        }
154    
155        /** Build a derivative class of a differentiable class.
156         * @param differentiableClass class to differentiate
157         * @return derivative class
158         * @throws DifferentiationException if the class cannot be differentiated
159         */
160        private Class<? extends UnivariateDerivative>
161        createDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
162            throws DifferentiationException {
163            try {
164    
165                // set up both ends of the class transform chain
166                final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class";
167                final InputStream stream = differentiableClass.getResourceAsStream(classResourceName);
168                final ClassReader reader = new ClassReader(stream);
169                final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
170    
171                // differentiate the function embedded in the differentiable class
172                final ClassDifferentiator differentiator = new ClassDifferentiator(mathClasses, writer);
173                reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
174                differentiator.reportErrors();
175    
176                // create the derivative class
177                return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer);
178    
179            } catch (IOException ioe) {
180                throw new DifferentiationException("class {0} cannot be read ({1})",
181                                              differentiableClass.getName(), ioe.getMessage());
182            }
183        }
184    
185        /** Class loader generating derivative classes. */
186        private static class DerivativeLoader extends ClassLoader {
187    
188            /** Simple constructor.
189             * @param differentiableClass differentiable class
190             */
191            public DerivativeLoader(final Class<? extends UnivariateDifferentiable> differentiableClass) {
192                super(differentiableClass.getClassLoader());
193            }
194    
195            /** Define a derivative class.
196             * @param differentiator class differentiator
197             * @param writer class writer
198             * @return a generated derivative class
199             */
200            @SuppressWarnings("unchecked")
201            public Class<? extends UnivariateDerivative>
202            defineClass(final ClassDifferentiator differentiator, final ClassWriter writer) {
203                final String name = differentiator.getDerivativeClassName().replace('/', '.');
204                final byte[] bytecode = writer.toByteArray();
205                return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length);
206            }
207        }
208    
209    }