001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.nabla.algorithmic.forward; 018 019 import java.io.IOException; 020 import java.io.InputStream; 021 import java.io.OutputStream; 022 import java.lang.reflect.Constructor; 023 import java.lang.reflect.InvocationTargetException; 024 import java.util.HashMap; 025 import java.util.HashSet; 026 import java.util.Set; 027 028 import org.apache.commons.nabla.algorithmic.forward.analysis.ClassDifferentiator; 029 import org.apache.commons.nabla.core.DifferentiationException; 030 import org.apache.commons.nabla.core.UnivariateDerivative; 031 import org.apache.commons.nabla.core.UnivariateDifferentiable; 032 import org.apache.commons.nabla.core.UnivariateDifferentiator; 033 import org.objectweb.asm.ClassReader; 034 import org.objectweb.asm.ClassWriter; 035 036 /** Automatic differentiator class based on bytecode analysis. 037 * <p>This class is an implementation of the {@link UnivariateDifferentiator} 038 * interface that computes <em>exact</em> differentials completely automatically 039 * and generate java classes and instances that compute the differential 040 * of the function as if they were hand-coded and compiled.</p> 041 * <p>The derivative bytecode created the first time an instance of a given class 042 * is differentiated is cached and will be reused if other instances of the same class 043 * are to be created later. The cache can also be dumped in a jar file for 044 * use in an application without bringing the full nabla library and its 045 * dependencies.</p> 046 * <p>This differentiator can handle only pure bytecode methods and known methods 047 * from math implementation classes like {@link java.lang.Math Math} or 048 * {@link java.lang.StrictMath StrictMath}. Pure bytecode methods are analyzed 049 * and converted. Methods from math implementation classes are only recognized 050 * by class and name and replaced by predefined derivative code.</p> 051 * @see org.apache.commons.nabla.caching.FetchDifferentiator 052 */ 053 public class ForwardAlgorithmicDifferentiator implements UnivariateDifferentiator { 054 055 /** UnivariateDifferentiable/UnivariateDerivative map. */ 056 private final HashMap<Class<? extends UnivariateDifferentiable>, 057 Class<? extends UnivariateDerivative>> map; 058 059 /** Math implementation classes. */ 060 private final Set<String> mathClasses; 061 062 /** Simple constructor. 063 * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p> 064 */ 065 public ForwardAlgorithmicDifferentiator() { 066 map = new HashMap<Class<? extends UnivariateDifferentiable>, 067 Class<? extends UnivariateDerivative>>(); 068 mathClasses = new HashSet<String>(); 069 addMathImplementation(Math.class); 070 addMathImplementation(StrictMath.class); 071 } 072 073 /** Add an implementation class for mathematical functions. 074 * <p>At construction, the differentiator considers only the {@link 075 * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath} 076 * classes are math implementation classes. It may be useful to add 077 * other class for example to add some missing functions like 078 * inverse hyperbolic cosine that are not provided by the standard 079 * java classes as of Java 1.6.</p> 080 * @param mathClass implementation class for mathematical functions 081 */ 082 public void addMathImplementation(final Class<?> mathClass) { 083 mathClasses.add(mathClass.getName().replace('.', '/')); 084 } 085 086 /** Dump the cache into a stream. 087 * @param out output stream where to dump the cache 088 */ 089 public void dumpCache(final OutputStream out) { 090 // TODO implement cache persistence 091 throw new RuntimeException("not implemented yet"); 092 } 093 094 /** {@inheritDoc} */ 095 public UnivariateDerivative differentiate(final UnivariateDifferentiable d) 096 throws DifferentiationException { 097 098 // get the derivative class 099 final Class<? extends UnivariateDerivative> derivativeClass = 100 getDerivativeClass(d.getClass()); 101 102 try { 103 104 // create the instance 105 final Constructor<? extends UnivariateDerivative> constructor = 106 derivativeClass.getConstructor(d.getClass()); 107 return constructor.newInstance(d); 108 109 } catch (InstantiationException ie) { 110 throw new DifferentiationException("abstract class {0} cannot be instantiated ({1})", 111 derivativeClass.getName(), ie.getMessage()); 112 } catch (IllegalAccessException iae) { 113 throw new DifferentiationException("illegal access to class {0} constructor ({1})", 114 derivativeClass.getName(), iae.getMessage()); 115 } catch (NoSuchMethodException nsme) { 116 throw new DifferentiationException("class {0} cannot be built from an instance of class {1} ({2})", 117 derivativeClass.getName(), d.getClass().getName(), nsme.getMessage()); 118 } catch (InvocationTargetException ite) { 119 throw new DifferentiationException("class {0} instantiation from an instance of class {1} failed ({2})", 120 derivativeClass.getName(), d.getClass().getName(), ite.getMessage()); 121 } 122 123 } 124 125 /** Get the derivative class of a differentiable class. 126 * <p>The derivative class is either built on the fly 127 * or retrieved from the cache if it has been built previously.</p> 128 * @param differentiableClass class to differentiate 129 * @return derivative class 130 * @throws DifferentiationException if the class cannot be differentiated 131 */ 132 private Class<? extends UnivariateDerivative> 133 getDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass) 134 throws DifferentiationException { 135 136 // lookup in the map if the class has already been differentiated 137 Class<? extends UnivariateDerivative> derivativeClass = 138 map.get(differentiableClass); 139 140 // build the derivative class if it does not exist yet 141 if (derivativeClass == null) { 142 // perform analytical differentiation 143 derivativeClass = createDerivativeClass(differentiableClass); 144 145 // put the newly created class in the map 146 map.put(differentiableClass, derivativeClass); 147 148 } 149 150 // return the derivative class 151 return derivativeClass; 152 153 } 154 155 /** Build a derivative class of a differentiable class. 156 * @param differentiableClass class to differentiate 157 * @return derivative class 158 * @throws DifferentiationException if the class cannot be differentiated 159 */ 160 private Class<? extends UnivariateDerivative> 161 createDerivativeClass(final Class<? extends UnivariateDifferentiable> differentiableClass) 162 throws DifferentiationException { 163 try { 164 165 // set up both ends of the class transform chain 166 final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class"; 167 final InputStream stream = differentiableClass.getResourceAsStream(classResourceName); 168 final ClassReader reader = new ClassReader(stream); 169 final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES); 170 171 // differentiate the function embedded in the differentiable class 172 final ClassDifferentiator differentiator = new ClassDifferentiator(mathClasses, writer); 173 reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); 174 differentiator.reportErrors(); 175 176 // create the derivative class 177 return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer); 178 179 } catch (IOException ioe) { 180 throw new DifferentiationException("class {0} cannot be read ({1})", 181 differentiableClass.getName(), ioe.getMessage()); 182 } 183 } 184 185 /** Class loader generating derivative classes. */ 186 private static class DerivativeLoader extends ClassLoader { 187 188 /** Simple constructor. 189 * @param differentiableClass differentiable class 190 */ 191 public DerivativeLoader(final Class<? extends UnivariateDifferentiable> differentiableClass) { 192 super(differentiableClass.getClassLoader()); 193 } 194 195 /** Define a derivative class. 196 * @param differentiator class differentiator 197 * @param writer class writer 198 * @return a generated derivative class 199 */ 200 @SuppressWarnings("unchecked") 201 public Class<? extends UnivariateDerivative> 202 defineClass(final ClassDifferentiator differentiator, final ClassWriter writer) { 203 final String name = differentiator.getDerivativeClassName().replace('/', '.'); 204 final byte[] bytecode = writer.toByteArray(); 205 return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length); 206 } 207 } 208 209 }