001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.nabla.forward;
018
019 import java.io.IOException;
020 import java.io.OutputStream;
021 import java.lang.reflect.Constructor;
022 import java.lang.reflect.InvocationTargetException;
023 import java.util.HashMap;
024 import java.util.HashSet;
025 import java.util.Set;
026
027 import org.apache.commons.math3.analysis.UnivariateFunction;
028 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
029 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
030 import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator;
031 import org.apache.commons.math3.util.FastMath;
032 import org.apache.commons.nabla.DifferentiationException;
033 import org.apache.commons.nabla.NablaMessages;
034 import org.apache.commons.nabla.forward.analysis.ClassDifferentiator;
035 import org.objectweb.asm.ClassWriter;
036 import org.objectweb.asm.Type;
037 import org.objectweb.asm.tree.ClassNode;
038
039 /** Algorithmic differentiator class in forward mode based on bytecode analysis.
040 * <p>This class is an implementation of the {@link UnivariateFunctionDifferentiator}
041 * interface that computes <em>exact</em> differentials completely automatically
042 * and generate java classes and instances that compute the differential
043 * of the function as if they were hand-coded and compiled.</p>
044 * <p>The derivative bytecode created the first time an instance of a given class
045 * is differentiated is cached and will be reused if other instances of the same class
046 * are to be created later. The cache can also be dumped in a jar file for
047 * use in an application without bringing the full nabla library and its
048 * dependencies.</p>
049 * <p>This differentiator can handle only pure bytecode methods and known methods
050 * from math implementation classes like {@link java.lang.Math Math}, {@link
051 * java.lang.StrictMath StrictMath} or {@link FastMath}. Pure bytecode methods are
052 * analyzed and converted. Methods from math implementation classes are only
053 * recognized by class and name and replaced by predefined derivative code.</p>
054 * @see org.apache.commons.nabla.caching.FetchDifferentiator
055 * @version $Id$
056 */
057 public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
058
059 /** UnivariateFunction/UnivariateDifferentiableFunction map. */
060 private final HashMap<Class<? extends UnivariateFunction>,
061 Class<? extends UnivariateDifferentiableFunction>> map;
062
063 /** Class name/ bytecode map. */
064 private final HashMap<String, byte[]> byteCodeMap;
065
066 /** Math implementation classes. */
067 private final Set<String> mathClasses;
068
069 /** Simple constructor.
070 * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
071 */
072 public ForwardModeDifferentiator() {
073 map = new HashMap<Class<? extends UnivariateFunction>,
074 Class<? extends UnivariateDifferentiableFunction>>();
075 byteCodeMap = new HashMap<String, byte[]>();
076 mathClasses = new HashSet<String>();
077 addMathImplementation(Math.class);
078 addMathImplementation(StrictMath.class);
079 addMathImplementation(FastMath.class);
080 }
081
082 /** Add an implementation class for mathematical functions.
083 * <p>At construction, the differentiator considers only the {@link
084 * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
085 * classes are math implementation classes. It may be useful to add
086 * other classes for example to add some missing functions like
087 * inverse hyperbolic cosine that are not provided by the standard
088 * java classes as of Java 1.6.</p>
089 * @param mathClass implementation class for mathematical functions
090 */
091 public void addMathImplementation(final Class<?> mathClass) {
092 mathClasses.add(mathClass.getName().replace('.', '/'));
093 }
094
095 /** Dump the cache into a stream.
096 * @param out output stream where to dump the cache
097 */
098 public void dumpCache(final OutputStream out) {
099 // TODO: implement cache persistence
100 throw new RuntimeException("not implemented yet");
101 }
102
103 /** {@inheritDoc} */
104 public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
105
106 // get the derivative class
107 final Class<? extends UnivariateDifferentiableFunction> derivativeClass =
108 getDerivativeClass(d.getClass());
109
110 try {
111
112 // create the instance
113 final Constructor<? extends UnivariateDifferentiableFunction> constructor =
114 derivativeClass.getConstructor(d.getClass());
115 return constructor.newInstance(d);
116
117 } catch (InstantiationException ie) {
118 throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS,
119 derivativeClass.getName(), ie.getMessage());
120 } catch (IllegalAccessException iae) {
121 throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR,
122 derivativeClass.getName(), iae.getMessage());
123 } catch (NoSuchMethodException nsme) {
124 throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS,
125 derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
126 } catch (InvocationTargetException ite) {
127 throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE,
128 derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
129 } catch (VerifyError ve) {
130 throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE,
131 derivativeClass.getName(), d.getClass().getName(), ve.getMessage());
132 }
133
134 }
135
136 /** Get the derivative class of a differentiable class.
137 * <p>The derivative class is either built on the fly
138 * or retrieved from the cache if it has been built previously.</p>
139 * @param differentiableClass class to differentiate
140 * @return derivative class
141 * @throws DifferentiationException if the class cannot be differentiated
142 */
143 private Class<? extends UnivariateDifferentiableFunction>
144 getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
145 throws DifferentiationException {
146
147 // lookup in the map if the class has already been differentiated
148 Class<? extends UnivariateDifferentiableFunction> derivativeClass =
149 map.get(differentiableClass);
150
151 // build the derivative class if it does not exist yet
152 if (derivativeClass == null) {
153
154 // perform algorithmic differentiation
155 derivativeClass = createDerivativeClass(differentiableClass);
156
157 // put the newly created class in the map
158 map.put(differentiableClass, derivativeClass);
159
160 }
161
162 // return the derivative class
163 return derivativeClass;
164
165 }
166
167 /** Build a derivative class of a differentiable class.
168 * @param differentiableClass class to differentiate
169 * @return derivative class
170 * @throws DifferentiationException if the class cannot be differentiated
171 */
172 private Class<? extends UnivariateDifferentiableFunction>
173 createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
174 throws DifferentiationException {
175 try {
176
177 // differentiate the function embedded in the differentiable class
178 final ClassDifferentiator differentiator =
179 new ClassDifferentiator(differentiableClass, mathClasses);
180 final Type dsType = Type.getType(DerivativeStructure.class);
181 differentiator.differentiateMethod("value",
182 Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
183 Type.getMethodDescriptor(dsType, dsType));
184
185 // create the derivative class
186 final ClassNode derived = differentiator.getDerivedClass();
187 final ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
188 final String name = derived.name.replace('/', '.');
189 derived.accept(writer);
190 final byte[] bytecode = writer.toByteArray();
191
192 final Class<? extends UnivariateDifferentiableFunction> dClass =
193 new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
194 byteCodeMap.put(name, bytecode);
195 return dClass;
196
197 } catch (IOException ioe) {
198 throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
199 differentiableClass.getName(), ioe.getMessage());
200 }
201 }
202
203 /** Class loader generating derivative classes. */
204 private static class DerivativeLoader extends ClassLoader {
205
206 /** Simple constructor.
207 * @param differentiableClass differentiable class
208 */
209 public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) {
210 super(differentiableClass.getClassLoader());
211 }
212
213 /** Define a derivative class.
214 * @param name name of the differentiated class
215 * @param bytecode bytecode of the differentiated class
216 * @return a generated derivative class
217 */
218 @SuppressWarnings("unchecked")
219 public Class<? extends UnivariateDifferentiableFunction>
220 defineClass(final String name, final byte[] bytecode) {
221 return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length);
222 }
223 }
224
225 }