001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.nabla.forward; 018 019 import java.io.IOException; 020 import java.io.OutputStream; 021 import java.lang.reflect.Constructor; 022 import java.lang.reflect.InvocationTargetException; 023 import java.util.HashMap; 024 import java.util.HashSet; 025 import java.util.Set; 026 027 import org.apache.commons.math3.analysis.UnivariateFunction; 028 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; 029 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 030 import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator; 031 import org.apache.commons.math3.util.FastMath; 032 import org.apache.commons.nabla.DifferentiationException; 033 import org.apache.commons.nabla.NablaMessages; 034 import org.apache.commons.nabla.forward.analysis.ClassDifferentiator; 035 import org.objectweb.asm.ClassWriter; 036 import org.objectweb.asm.Type; 037 import org.objectweb.asm.tree.ClassNode; 038 039 /** Algorithmic differentiator class in forward mode based on bytecode analysis. 040 * <p>This class is an implementation of the {@link UnivariateFunctionDifferentiator} 041 * interface that computes <em>exact</em> differentials completely automatically 042 * and generate java classes and instances that compute the differential 043 * of the function as if they were hand-coded and compiled.</p> 044 * <p>The derivative bytecode created the first time an instance of a given class 045 * is differentiated is cached and will be reused if other instances of the same class 046 * are to be created later. The cache can also be dumped in a jar file for 047 * use in an application without bringing the full nabla library and its 048 * dependencies.</p> 049 * <p>This differentiator can handle only pure bytecode methods and known methods 050 * from math implementation classes like {@link java.lang.Math Math}, {@link 051 * java.lang.StrictMath StrictMath} or {@link FastMath}. Pure bytecode methods are 052 * analyzed and converted. Methods from math implementation classes are only 053 * recognized by class and name and replaced by predefined derivative code.</p> 054 * @see org.apache.commons.nabla.caching.FetchDifferentiator 055 * @version $Id$ 056 */ 057 public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator { 058 059 /** UnivariateFunction/UnivariateDifferentiableFunction map. */ 060 private final HashMap<Class<? extends UnivariateFunction>, 061 Class<? extends UnivariateDifferentiableFunction>> map; 062 063 /** Class name/ bytecode map. */ 064 private final HashMap<String, byte[]> byteCodeMap; 065 066 /** Math implementation classes. */ 067 private final Set<String> mathClasses; 068 069 /** Simple constructor. 070 * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p> 071 */ 072 public ForwardModeDifferentiator() { 073 map = new HashMap<Class<? extends UnivariateFunction>, 074 Class<? extends UnivariateDifferentiableFunction>>(); 075 byteCodeMap = new HashMap<String, byte[]>(); 076 mathClasses = new HashSet<String>(); 077 addMathImplementation(Math.class); 078 addMathImplementation(StrictMath.class); 079 addMathImplementation(FastMath.class); 080 } 081 082 /** Add an implementation class for mathematical functions. 083 * <p>At construction, the differentiator considers only the {@link 084 * java.lang.Math Math} and {@link java.lang.StrictMath StrictMath} 085 * classes are math implementation classes. It may be useful to add 086 * other classes for example to add some missing functions like 087 * inverse hyperbolic cosine that are not provided by the standard 088 * java classes as of Java 1.6.</p> 089 * @param mathClass implementation class for mathematical functions 090 */ 091 public void addMathImplementation(final Class<?> mathClass) { 092 mathClasses.add(mathClass.getName().replace('.', '/')); 093 } 094 095 /** Dump the cache into a stream. 096 * @param out output stream where to dump the cache 097 */ 098 public void dumpCache(final OutputStream out) { 099 // TODO: implement cache persistence 100 throw new RuntimeException("not implemented yet"); 101 } 102 103 /** {@inheritDoc} */ 104 public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) { 105 106 // get the derivative class 107 final Class<? extends UnivariateDifferentiableFunction> derivativeClass = 108 getDerivativeClass(d.getClass()); 109 110 try { 111 112 // create the instance 113 final Constructor<? extends UnivariateDifferentiableFunction> constructor = 114 derivativeClass.getConstructor(d.getClass()); 115 return constructor.newInstance(d); 116 117 } catch (InstantiationException ie) { 118 throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_ABSTRACT_CLASS, 119 derivativeClass.getName(), ie.getMessage()); 120 } catch (IllegalAccessException iae) { 121 throw new DifferentiationException(NablaMessages.ILLEGAL_ACCESS_TO_CONSTRUCTOR, 122 derivativeClass.getName(), iae.getMessage()); 123 } catch (NoSuchMethodException nsme) { 124 throw new DifferentiationException(NablaMessages.CANNOT_BUILD_CLASS_FROM_OTHER_CLASS, 125 derivativeClass.getName(), d.getClass().getName(), nsme.getMessage()); 126 } catch (InvocationTargetException ite) { 127 throw new DifferentiationException(NablaMessages.CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE, 128 derivativeClass.getName(), d.getClass().getName(), ite.getMessage()); 129 } catch (VerifyError ve) { 130 throw new DifferentiationException(NablaMessages.INCORRECT_GENERATED_CODE, 131 derivativeClass.getName(), d.getClass().getName(), ve.getMessage()); 132 } 133 134 } 135 136 /** Get the derivative class of a differentiable class. 137 * <p>The derivative class is either built on the fly 138 * or retrieved from the cache if it has been built previously.</p> 139 * @param differentiableClass class to differentiate 140 * @return derivative class 141 * @throws DifferentiationException if the class cannot be differentiated 142 */ 143 private Class<? extends UnivariateDifferentiableFunction> 144 getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass) 145 throws DifferentiationException { 146 147 // lookup in the map if the class has already been differentiated 148 Class<? extends UnivariateDifferentiableFunction> derivativeClass = 149 map.get(differentiableClass); 150 151 // build the derivative class if it does not exist yet 152 if (derivativeClass == null) { 153 154 // perform algorithmic differentiation 155 derivativeClass = createDerivativeClass(differentiableClass); 156 157 // put the newly created class in the map 158 map.put(differentiableClass, derivativeClass); 159 160 } 161 162 // return the derivative class 163 return derivativeClass; 164 165 } 166 167 /** Build a derivative class of a differentiable class. 168 * @param differentiableClass class to differentiate 169 * @return derivative class 170 * @throws DifferentiationException if the class cannot be differentiated 171 */ 172 private Class<? extends UnivariateDifferentiableFunction> 173 createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass) 174 throws DifferentiationException { 175 try { 176 177 // differentiate the function embedded in the differentiable class 178 final ClassDifferentiator differentiator = 179 new ClassDifferentiator(differentiableClass, mathClasses); 180 final Type dsType = Type.getType(DerivativeStructure.class); 181 differentiator.differentiateMethod("value", 182 Type.getMethodDescriptor(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE), 183 Type.getMethodDescriptor(dsType, dsType)); 184 185 // create the derivative class 186 final ClassNode derived = differentiator.getDerivedClass(); 187 final ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES); 188 final String name = derived.name.replace('/', '.'); 189 derived.accept(writer); 190 final byte[] bytecode = writer.toByteArray(); 191 192 final Class<? extends UnivariateDifferentiableFunction> dClass = 193 new DerivativeLoader(differentiableClass).defineClass(name, bytecode); 194 byteCodeMap.put(name, bytecode); 195 return dClass; 196 197 } catch (IOException ioe) { 198 throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS, 199 differentiableClass.getName(), ioe.getMessage()); 200 } 201 } 202 203 /** Class loader generating derivative classes. */ 204 private static class DerivativeLoader extends ClassLoader { 205 206 /** Simple constructor. 207 * @param differentiableClass differentiable class 208 */ 209 public DerivativeLoader(final Class<? extends UnivariateFunction> differentiableClass) { 210 super(differentiableClass.getClassLoader()); 211 } 212 213 /** Define a derivative class. 214 * @param name name of the differentiated class 215 * @param bytecode bytecode of the differentiated class 216 * @return a generated derivative class 217 */ 218 @SuppressWarnings("unchecked") 219 public Class<? extends UnivariateDifferentiableFunction> 220 defineClass(final String name, final byte[] bytecode) { 221 return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length); 222 } 223 } 224 225 }