001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.nabla.algorithmic.forward;
018
019 import java.io.IOException;
020 import java.io.InputStream;
021 import java.io.OutputStream;
022 import java.lang.reflect.Constructor;
023 import java.lang.reflect.InvocationTargetException;
024 import java.util.HashMap;
025 import java.util.HashSet;
026 import java.util.Set;
027
028 import org.apache.commons.nabla.algorithmic.forward.analysis.ClassDifferentiator;
029 import org.apache.commons.nabla.core.DifferentiationException;
030 import org.apache.commons.nabla.core.UnivariateDerivative;
031 import org.apache.commons.nabla.core.UnivariateDifferentiable;
032 import org.apache.commons.nabla.core.UnivariateDifferentiator;
033 import org.objectweb.asm.ClassReader;
034 import org.objectweb.asm.ClassWriter;
035
036 /** Automatic differentiator class based on bytecode analysis.
037 * <p>This class is an implementation of the {@link UnivariateDifferentiator}
038 * interface that computes <em>exact</em> differentials completely automatically
039 * and generate java classes and instances that compute the differential
040 * of the function as if they were hand-coded and compiled.</p>
041 * <p>The derivative bytecode created the first time an instance of a given class
042 * is differentiated is cached and will be reused if other instances of the same class
043 * are to be created later. The cache can also be dumped in a jar file for
044 * use in an application without bringing the full nabla library and its
045 * dependencies.</p>
046 * <p>This differentiator can handle only pure bytecode methods and known methods
047 * from math implementation classes like {@link java.lang.Math Math} or
048 * {@link java.lang.StrictMath StrictMath}. Pure bytecode methods are analyzed
049 * and converted. Methods from math implementation classes are only recognized
050 * by class and name and replaced by predefined derivative code.</p>
051 * @see org.apache.commons.nabla.caching.FetchDifferentiator
052 */
053 public class ForwardAlgorithmicDifferentiator implements UnivariateDifferentiator {
054
055 /** UnivariateDifferentiable/UnivariateDerivative map. */
056 private final HashMap<Class<? extends UnivariateDifferentiable>,
057 Class<? extends UnivariateDerivative>> map;
058
059 /** Math implementation classes. */
060 private final Set<String> mathClasses;
061
062 /** Simple constructor.
063 * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
064 */
065 public ForwardAlgorithmicDifferentiator() {
066 map = new HashMap<Class<? extends UnivariateDifferentiable>,
067 Class<? extends UnivariateDerivative>>();
068 mathClasses = new HashSet<String>();
069 addMathImplementation(Math.class);
070 addMathImplementation(StrictMath.class);
071 }
072
073 /** Add an implementation class for mathematical functions.
074 * <p>At construction, the differentiator considers only the {@link
075 * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath}
076 * classes are math implementation classes. It may be useful to add
077 * other class for example to add some missing functions like
078 * inverse hyperbolic cosine that are not provided by the standard
079 * java classes as of Java 1.6.</p>
080 * @param mathClass implementation class for mathematical functions
081 */
082 public void addMathImplementation(final Class<?> mathClass) {
083 mathClasses.add(mathClass.getName().replace('.', '/'));
084 }
085
086 /** Dump the cache into a stream.
087 * @param out output stream where to dump the cache
088 */
089 public void dumpCache(final OutputStream out) {
090 // TODO implement cache persistence
091 throw new RuntimeException("not implemented yet");
092 }
093
094 /** {@inheritDoc} */
095 public UnivariateDerivative differentiate(final UnivariateDifferentiable d)
096 throws DifferentiationException {
097
098 // get the derivative class
099 final Class<? extends UnivariateDerivative> derivativeClass =
100 getDerivativeClass(d.getClass());
101
102 try {
103
104 // create the instance
105 final Constructor<? extends UnivariateDerivative> constructor =
106 derivativeClass.getConstructor(d.getClass());
107 return constructor.newInstance(d);
108
109 } catch (InstantiationException ie) {
110 throw new DifferentiationException("abstract class {0} cannot be instantiated ({1})",
111 derivativeClass.getName(), ie.getMessage());
112 } catch (IllegalAccessException iae) {
113 throw new DifferentiationException("illegal access to class {0} constructor ({1})",
114 derivativeClass.getName(), iae.getMessage());
115 } catch (NoSuchMethodException nsme) {
116 throw new DifferentiationException("class {0} cannot be built from an instance of class {1} ({2})",
117 derivativeClass.getName(), d.getClass().getName(), nsme.getMessage());
118 } catch (InvocationTargetException ite) {
119 throw new DifferentiationException("class {0} instantiation from an instance of class {1} failed ({2})",
120 derivativeClass.getName(), d.getClass().getName(), ite.getMessage());
121 }
122
123 }
124
125 /** Get the derivative class of a differentiable class.
126 * <p>The derivative class is either built on the fly
127 * or retrieved from the cache if it has been built previously.</p>
128 * @param differentiableClass class to differentiate
129 * @return derivative class
130 * @throws DifferentiationException if the class cannot be differentiated
131 */
132 private Class<? extends UnivariateDerivative>
133 getDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
134 throws DifferentiationException {
135
136 // lookup in the map if the class has already been differentiated
137 Class<? extends UnivariateDerivative> derivativeClass =
138 map.get(differentiableClass);
139
140 // build the derivative class if it does not exist yet
141 if (derivativeClass == null) {
142 // perform analytical differentiation
143 derivativeClass = createDerivativeClass(differentiableClass);
144
145 // put the newly created class in the map
146 map.put(differentiableClass, derivativeClass);
147
148 }
149
150 // return the derivative class
151 return derivativeClass;
152
153 }
154
155 /** Build a derivative class of a differentiable class.
156 * @param differentiableClass class to differentiate
157 * @return derivative class
158 * @throws DifferentiationException if the class cannot be differentiated
159 */
160 private Class<? extends UnivariateDerivative>
161 createDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass)
162 throws DifferentiationException {
163 try {
164
165 // set up both ends of the class transform chain
166 final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class";
167 final InputStream stream = differentiableClass.getResourceAsStream(classResourceName);
168 final ClassReader reader = new ClassReader(stream);
169 final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
170
171 // differentiate the function embedded in the differentiable class
172 final ClassDifferentiator differentiator = new ClassDifferentiator(mathClasses, writer);
173 reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
174 differentiator.reportErrors();
175
176 // create the derivative class
177 return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer);
178
179 } catch (IOException ioe) {
180 throw new DifferentiationException("class {0} cannot be read ({1})",
181 differentiableClass.getName(), ioe.getMessage());
182 }
183 }
184
185 /** Class loader generating derivative classes. */
186 private static class DerivativeLoader extends ClassLoader {
187
188 /** Simple constructor.
189 * @param differentiableClass differentiable class
190 */
191 public DerivativeLoader(final Class<? extends UnivariateDifferentiable> differentiableClass) {
192 super(differentiableClass.getClassLoader());
193 }
194
195 /** Define a derivative class.
196 * @param differentiator class differentiator
197 * @param writer class writer
198 * @return a generated derivative class
199 */
200 @SuppressWarnings("unchecked")
201 public Class<? extends UnivariateDerivative>
202 defineClass(final ClassDifferentiator differentiator, final ClassWriter writer) {
203 final String name = differentiator.getDerivativeClassName().replace('/', '.');
204 final byte[] bytecode = writer.toByteArray();
205 return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length);
206 }
207 }
208
209 }