View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.nabla.forward;
18  
19  import java.io.IOException;
20  import java.io.OutputStream;
21  import java.lang.reflect.Constructor;
22  import java.lang.reflect.InvocationTargetException;
23  import java.util.HashMap;
24  import java.util.HashSet;
25  import java.util.Set;
26  
27  import org.apache.commons.math3.analysis.UnivariateFunction;
28  import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
29  import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
30  import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator;
31  import org.apache.commons.math3.util.FastMath;
32  import org.apache.commons.nabla.DifferentiationException;
33  import org.apache.commons.nabla.NablaMessages;
34  import org.apache.commons.nabla.forward.analysis.ClassDifferentiator;
35  import org.objectweb.asm.ClassWriter;
36  import org.objectweb.asm.Type;
37  import org.objectweb.asm.tree.ClassNode;
38  
39  /** Algorithmic differentiator class in forward mode based on bytecode analysis.
40   * <p>This class is an implementation of the {@link UnivariateFunctionDifferentiator}
41   * interface that computes <em>exact</em> differentials completely automatically
42   * and generate java classes and instances that compute the differential
43   * of the function as if they were hand-coded and compiled.</p>
44   * <p>The derivative bytecode created the first time an instance of a given class
45   * is differentiated is cached and will be reused if other instances of the same class
46   * are to be created later. The cache can also be dumped in a jar file for
47   * use in an application without bringing the full nabla library and its
48   * dependencies.</p>
49   * <p>This differentiator can handle only pure bytecode methods and known methods
50   * from math implementation classes like {@link java.lang.Math Math}, {@link
51   * java.lang.StrictMath StrictMath} or {@link FastMath}. Pure bytecode methods are
52   * analyzed and converted. Methods from math implementation classes are only
53   * recognized by class and name and replaced by predefined derivative code.</p>
54   * @see org.apache.commons.nabla.caching.FetchDifferentiator
55   * @version $Id$
56   */
57  public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
58  
59      /** UnivariateFunction/UnivariateDifferentiableFunction map. */
60      private final HashMap<Class<? extends UnivariateFunction>,
61                            Class<? extends UnivariateDifferentiableFunction>> map;
62  
63      /** Class name/ bytecode map. */
64      private final HashMap<String, byte[]> byteCodeMap;
65  
66      /** Math implementation classes. */
67      private final Set<String> mathClasses;
68  
69      /** Simple constructor.
70       * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
71       */
72      public ForwardModeDifferentiator() {
73          map         = new HashMap<Class<? extends UnivariateFunction>,
74                                    Class<? extends UnivariateDifferentiableFunction>>();
75          byteCodeMap = new HashMap<String, byte[]>();
76          mathClasses = new HashSet<String>();
77          addMathImplementation(Math.class);
78          addMathImplementation(StrictMath.class);
79          addMathImplementation(FastMath.class);
80      }
81  
82      /** Add an implementation class for mathematical functions.
83       * <p>At construction, the differentiator considers only the {@link
84       * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
85       * classes are math implementation classes. It may be useful to add
86       * other classes for example to add some missing functions like
87       * inverse hyperbolic cosine that are not provided by the standard
88       * java classes as of Java 1.6.</p>
89       * @param mathClass implementation class for mathematical functions
90       */
91      public void addMathImplementation(final Class<?> mathClass) {
92          mathClasses.add(mathClass.getName().replace('.', '/'));
93      }
94  
95      /** Dump the cache into a stream.
96       * @param out output stream where to dump the cache
97       */
98      public void dumpCache(final OutputStream out) {
99          // TODO: implement cache persistence
100         throw new RuntimeException("not implemented yet");
101     }
102 
103     /** {@inheritDoc} */
104     public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
105 
106         // get the derivative class
107         final Class<? extends UnivariateDifferentiableFunction> derivativeClass =
108             getDerivativeClass(d.getClass());
109 
110         try {
111 
112             // create the instance
113             final Constructor<? extends UnivariateDifferentiableFunction> constructor =
114                 derivativeClass.getConstructor(d.getClass());
115             return constructor.newInstance(d);
116 
117         } catch (InstantiationException ie) {
118             throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS,
119                                                derivativeClass.getName(), ie.getMessage());
120         } catch (IllegalAccessException iae) {
121             throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR,
122                                                derivativeClass.getName(), iae.getMessage());
123         } catch (NoSuchMethodException nsme) {
124             throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS,
125                                                derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
126         } catch (InvocationTargetException ite) {
127             throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE,
128                                                derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
129         } catch (VerifyError ve) {
130             throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE,
131                                                derivativeClass.getName(), d.getClass().getName(), ve.getMessage());
132         }
133 
134     }
135 
136     /** Get the derivative class of a differentiable class.
137      * <p>The derivative class is either built on the fly
138      * or retrieved from the cache if it has been built previously.</p>
139      * @param differentiableClass class to differentiate
140      * @return derivative class
141      * @throws DifferentiationException if the class cannot be differentiated
142      */
143     private Class<? extends UnivariateDifferentiableFunction>
144     getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
145         throws DifferentiationException {
146 
147         // lookup in the map if the class has already been differentiated
148         Class<? extends UnivariateDifferentiableFunction> derivativeClass =
149             map.get(differentiableClass);
150 
151         // build the derivative class if it does not exist yet
152         if (derivativeClass == null) {
153 
154             // perform algorithmic differentiation
155             derivativeClass = createDerivativeClass(differentiableClass);
156 
157             // put the newly created class in the map
158             map.put(differentiableClass, derivativeClass);
159 
160         }
161 
162         // return the derivative class
163         return derivativeClass;
164 
165     }
166 
167     /** Build a derivative class of a differentiable class.
168      * @param differentiableClass class to differentiate
169      * @return derivative class
170      * @throws DifferentiationException if the class cannot be differentiated
171      */
172     private Class<? extends UnivariateDifferentiableFunction>
173     createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
174         throws DifferentiationException {
175         try {
176 
177             // differentiate the function embedded in the differentiable class
178             final ClassDifferentiator differentiator =
179                 new ClassDifferentiator(differentiableClass, mathClasses);
180             final Type dsType = Type.getType(DerivativeStructure.class);
181             differentiator.differentiateMethod("value",
182                                                Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
183                                                Type.getMethodDescriptor(dsType, dsType));
184 
185             // create the derivative class
186             final ClassNode   derived = differentiator.getDerivedClass();
187             final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
188             final String name = derived.name.replace('/', '.');
189             derived.accept(writer);
190             final byte[] bytecode = writer.toByteArray();
191 
192             final Class<? extends UnivariateDifferentiableFunction> dClass =
193                     new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
194             byteCodeMap.put(name, bytecode);
195             return dClass;
196 
197         } catch (IOException ioe) {
198             throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
199                                                differentiableClass.getName(), ioe.getMessage());
200         }
201     }
202 
203     /** Class loader generating derivative classes. */
204     private static class DerivativeLoader extends ClassLoader {
205 
206         /** Simple constructor.
207          * @param differentiableClass differentiable class
208          */
209         public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) {
210             super(differentiableClass.getClassLoader());
211         }
212 
213         /** Define a derivative class.
214          * @param name name of the differentiated class
215          * @param bytecode bytecode of the differentiated class
216          * @return a generated derivative class
217          */
218         @SuppressWarnings("unchecked")
219         public Class<? extends UnivariateDifferentiableFunction>
220         defineClass(final String name, final byte[] bytecode) {
221             return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length);
222         }
223     }
224 
225 }