001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.nabla.algorithmic.forward.analysis; 018 019 import java.util.ArrayList; 020 import java.util.HashMap; 021 import java.util.HashSet; 022 import java.util.IdentityHashMap; 023 import java.util.Iterator; 024 import java.util.List; 025 import java.util.Map; 026 import java.util.Set; 027 028 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer1; 029 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer12; 030 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer2; 031 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer1; 032 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer12; 033 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer2; 034 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer1; 035 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer12; 036 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer2; 037 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DNegTransformer; 038 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer1; 039 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer12; 040 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer2; 041 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer1; 042 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer12; 043 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer2; 044 import org.apache.commons.nabla.algorithmic.forward.functions.AcosTransformer; 045 import org.apache.commons.nabla.algorithmic.forward.functions.AcoshTransformer; 046 import org.apache.commons.nabla.algorithmic.forward.functions.AsinTransformer; 047 import org.apache.commons.nabla.algorithmic.forward.functions.AsinhTransformer; 048 import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer1; 049 import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer12; 050 import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer2; 051 import org.apache.commons.nabla.algorithmic.forward.functions.AtanTransformer; 052 import org.apache.commons.nabla.algorithmic.forward.functions.AtanhTransformer; 053 import org.apache.commons.nabla.algorithmic.forward.functions.CbrtTransformer; 054 import org.apache.commons.nabla.algorithmic.forward.functions.CosTransformer; 055 import org.apache.commons.nabla.algorithmic.forward.functions.CoshTransformer; 056 import org.apache.commons.nabla.algorithmic.forward.functions.ExpTransformer; 057 import org.apache.commons.nabla.algorithmic.forward.functions.Expm1Transformer; 058 import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer1; 059 import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer12; 060 import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer2; 061 import org.apache.commons.nabla.algorithmic.forward.functions.Log10Transformer; 062 import org.apache.commons.nabla.algorithmic.forward.functions.Log1pTransformer; 063 import org.apache.commons.nabla.algorithmic.forward.functions.LogTransformer; 064 import org.apache.commons.nabla.algorithmic.forward.functions.MathInvocationTransformer; 065 import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer1; 066 import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer12; 067 import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer2; 068 import org.apache.commons.nabla.algorithmic.forward.functions.SinTransformer; 069 import org.apache.commons.nabla.algorithmic.forward.functions.SinhTransformer; 070 import org.apache.commons.nabla.algorithmic.forward.functions.SqrtTransformer; 071 import org.apache.commons.nabla.algorithmic.forward.functions.TanTransformer; 072 import org.apache.commons.nabla.algorithmic.forward.functions.TanhTransformer; 073 import org.apache.commons.nabla.algorithmic.forward.instructions.DLoadTransformer; 074 import org.apache.commons.nabla.algorithmic.forward.instructions.DReturnTransformer; 075 import org.apache.commons.nabla.algorithmic.forward.instructions.DStoreTransformer; 076 import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer1; 077 import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer12; 078 import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer2; 079 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2Transformer; 080 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X1Transformer; 081 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer1; 082 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer12; 083 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer2; 084 import org.apache.commons.nabla.algorithmic.forward.instructions.NarrowingTransformer; 085 import org.apache.commons.nabla.algorithmic.forward.instructions.WideningTransformer; 086 import org.apache.commons.nabla.algorithmic.forward.trimming.DLoadPop2Trimmer; 087 import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDloadTrimmer; 088 import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDstoreTrimmer; 089 import org.apache.commons.nabla.core.DifferentialPair; 090 import org.apache.commons.nabla.core.DifferentiationException; 091 import org.objectweb.asm.MethodVisitor; 092 import org.objectweb.asm.Opcodes; 093 import org.objectweb.asm.tree.AbstractInsnNode; 094 import org.objectweb.asm.tree.FieldInsnNode; 095 import org.objectweb.asm.tree.IincInsnNode; 096 import org.objectweb.asm.tree.InsnList; 097 import org.objectweb.asm.tree.InsnNode; 098 import org.objectweb.asm.tree.LabelNode; 099 import org.objectweb.asm.tree.MethodInsnNode; 100 import org.objectweb.asm.tree.MethodNode; 101 import org.objectweb.asm.tree.VarInsnNode; 102 import org.objectweb.asm.tree.analysis.Analyzer; 103 import org.objectweb.asm.tree.analysis.AnalyzerException; 104 import org.objectweb.asm.tree.analysis.BasicValue; 105 import org.objectweb.asm.tree.analysis.Frame; 106 import org.objectweb.asm.tree.analysis.Interpreter; 107 108 /** Class transforming a method computing a value to a method 109 * computing both a value and its differential. 110 */ 111 public class MethodDifferentiator extends MethodNode { 112 113 /** Name for the DifferentialPair class. */ 114 public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/'); 115 116 /** Descriptor for the DifferentialPair class. */ 117 public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";"; 118 119 /** Descriptor for the derivative class f method. */ 120 public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR; 121 122 /** Descriptor for <code>double f()</code> methods. */ 123 private static final String VOID_RETURN_D_DESCRIPTOR = "()D"; 124 125 /** Math functions transformer. */ 126 private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS = 127 new HashMap<String, MathInvocationTransformer>(); 128 129 static { 130 MATH_TRANSFORMERS.put("acos", new AcosTransformer()); 131 MATH_TRANSFORMERS.put("acosh", new AcoshTransformer()); 132 MATH_TRANSFORMERS.put("asin", new AsinTransformer()); 133 MATH_TRANSFORMERS.put("asinh", new AsinhTransformer()); 134 MATH_TRANSFORMERS.put("atan2_12", new Atan2Transformer12()); 135 MATH_TRANSFORMERS.put("atan2_1", new Atan2Transformer1()); 136 MATH_TRANSFORMERS.put("atan2_2", new Atan2Transformer2()); 137 MATH_TRANSFORMERS.put("atan", new AtanTransformer()); 138 MATH_TRANSFORMERS.put("atanh", new AtanhTransformer()); 139 MATH_TRANSFORMERS.put("cbrt", new CbrtTransformer()); 140 MATH_TRANSFORMERS.put("cos", new CosTransformer()); 141 MATH_TRANSFORMERS.put("cosh", new CoshTransformer()); 142 MATH_TRANSFORMERS.put("exp", new ExpTransformer()); 143 MATH_TRANSFORMERS.put("expm1", new Expm1Transformer()); 144 MATH_TRANSFORMERS.put("hypot_12", new HypotTransformer12()); 145 MATH_TRANSFORMERS.put("hypot_1", new HypotTransformer1()); 146 MATH_TRANSFORMERS.put("hypot_2", new HypotTransformer2()); 147 MATH_TRANSFORMERS.put("log10", new Log10Transformer()); 148 MATH_TRANSFORMERS.put("log1p", new Log1pTransformer()); 149 MATH_TRANSFORMERS.put("log", new LogTransformer()); 150 MATH_TRANSFORMERS.put("pow_12", new PowTransformer12()); 151 MATH_TRANSFORMERS.put("pow_1", new PowTransformer1()); 152 MATH_TRANSFORMERS.put("pow_2", new PowTransformer2()); 153 MATH_TRANSFORMERS.put("sin", new SinTransformer()); 154 MATH_TRANSFORMERS.put("sinh", new SinhTransformer()); 155 MATH_TRANSFORMERS.put("sqrt", new SqrtTransformer()); 156 MATH_TRANSFORMERS.put("tan", new TanTransformer()); 157 MATH_TRANSFORMERS.put("tanh", new TanhTransformer()); 158 } 159 160 /** Message format for unknown method. */ 161 private static final String UNKNOWN_METHOD_FMT = "unknown method {0}.{1}"; 162 163 /** Maximal number of temporary size 2 variables. */ 164 private static final int MAX_TEMP = 5; 165 166 /** Math implementation classes. */ 167 private final Set<String> mathClasses; 168 169 /** Generator to use. */ 170 private final MethodVisitor generator; 171 172 /** Used locals variables array. */ 173 private boolean[] usedLocals; 174 175 /** Primitive class name. */ 176 private final String primitiveName; 177 178 /** Error reporter to use. */ 179 private final ErrorReporter errorReporter; 180 181 /** Set of converted values. */ 182 private final Set<TrackingValue> converted; 183 184 /** Frames for the original method. */ 185 private final Map<AbstractInsnNode, Frame> frames; 186 187 /** Instructions successors array. */ 188 private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors; 189 190 /** Cloned labels map. */ 191 private final Map<LabelNode, LabelNode> clonedLabels; 192 193 /** Build a differentiator for a method. 194 * @param access access flags of the method 195 * @param name name of the method 196 * @param desc descriptor of the method 197 * @param signature signature of the method 198 * @param exceptions exceptions thrown by the method 199 * @param generator bytecode generator to use for the transformed method 200 * @param primitiveName primitive class name 201 * @param mathClasses math implementation classes 202 * @param errorReporter reporter used for delaying exceptions 203 */ 204 public MethodDifferentiator(final int access, final String name, final String desc, 205 final String signature, final String[] exceptions, 206 final MethodVisitor generator,final String primitiveName, 207 final Set<String> mathClasses, 208 final ErrorReporter errorReporter) { 209 210 super(access, name, desc, signature, exceptions); 211 this.generator = generator; 212 this.usedLocals = null; 213 this.primitiveName = primitiveName; 214 this.mathClasses = mathClasses; 215 this.errorReporter = errorReporter; 216 this.converted = new HashSet<TrackingValue>(); 217 this.frames = new IdentityHashMap<AbstractInsnNode, Frame>(); 218 this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>(); 219 this.clonedLabels = new HashMap<LabelNode, LabelNode>(); 220 221 } 222 223 /** {@inheritDoc} */ 224 @Override 225 public void visitEnd() { 226 try { 227 228 // at start, "this" and one differential pair are used 229 maxLocals = 2 * (maxLocals + MAX_TEMP) - 1; 230 usedLocals = new boolean[maxLocals]; 231 useLocal(0, 1); 232 useLocal(1, 4); 233 234 // add spare cells to hold new variables if needed 235 addSpareLocalVariables(); 236 237 // analyze the original code, tracing values production/consumption 238 final Frame[] array = 239 new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this); 240 241 // convert the array into a map, since code changes will shift all indices 242 for (int i = 0; i < array.length; ++i) { 243 frames.put(instructions.get(i), array[i]); 244 } 245 246 // identify the needed changes 247 final Set<AbstractInsnNode> changes = identifyChanges(); 248 249 if (changes.isEmpty()) { 250 251 // the method does not depend on the parameter at all! 252 // we replace all code by a simple "return DifferentialPair.ZERO;" 253 instructions.clear(); 254 instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR)); 255 instructions.add(new InsnNode(Opcodes.ARETURN)); 256 257 } else { 258 259 // perform the code changes 260 changeCode(changes); 261 262 // remove the local variables added at the beginning and not used 263 removeUnusedSpareLocalVariables(); 264 265 // trim generated instructions list 266 SwappedDloadTrimmer.getInstance().trim(instructions); 267 SwappedDstoreTrimmer.getInstance().trim(instructions); 268 DLoadPop2Trimmer.getInstance().trim(instructions); 269 270 } 271 272 // change the descriptor to its true final value 273 desc = DP_RETURN_DP_DESCRIPTOR; 274 275 // generate the method 276 accept(generator); 277 278 } catch (AnalyzerException ae) { 279 if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) { 280 errorReporter.register((DifferentiationException) ae.getCause()); 281 } else { 282 final DifferentiationException de = 283 new DifferentiationException("unable to analyze the {0}.{1} method ({2})", 284 new Object[] { 285 primitiveName, name, ae.getMessage() 286 }); 287 errorReporter.register(de); 288 } 289 } catch (DifferentiationException de) { 290 errorReporter.register(de); 291 } 292 } 293 294 /** Add spare cells for new local variables. 295 * <p>In order to ease conversion from double values to differential pairs, 296 * we start by reserving one spare cell between each original local variables. 297 * So we have to modify the indices in all instructions referencing local 298 * variables in the original code, to take into account the renumbering 299 * introduced by these spare cells. The spare cells by themselves will 300 * be referenced by the converted instructions in the following passes.</p> 301 * <p>The spare cells that will not be used will be reclaimed after 302 * conversion, to avoid wasting memory.</p> 303 * @exception DifferentiationException if local variables array has not been 304 * expanded appropriately beforehand 305 * @see #removeUnusedSpareLocalVariables() 306 */ 307 private void addSpareLocalVariables() throws DifferentiationException { 308 for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { 309 final AbstractInsnNode insn = (AbstractInsnNode) i.next(); 310 if (insn.getType() == AbstractInsnNode.VAR_INSN) { 311 final VarInsnNode varInsn = (VarInsnNode) insn; 312 if (varInsn.var > 2) { 313 varInsn.var = 2 * varInsn.var - 1; 314 final int opcode = varInsn.getOpcode(); 315 if ((opcode == Opcodes.ILOAD) || (opcode == Opcodes.FLOAD) || 316 (opcode == Opcodes.ALOAD) || (opcode == Opcodes.ISTORE) || 317 (opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) { 318 useLocal(varInsn.var, 1); 319 } else { 320 useLocal(varInsn.var, 2); 321 } 322 } 323 } else if (insn.getOpcode() == Opcodes.IINC) { 324 final IincInsnNode iincInsn = (IincInsnNode) insn; 325 if (iincInsn.var > 2) { 326 iincInsn.var = 2 * iincInsn.var - 1; 327 useLocal(iincInsn.var, 1); 328 } 329 } 330 } 331 } 332 333 /** Remove the unused spare cells introduced at conversion start. 334 * @see #addSpareLocalVariables() 335 */ 336 private void removeUnusedSpareLocalVariables() { 337 for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { 338 final AbstractInsnNode insn = (AbstractInsnNode) i.next(); 339 if (insn.getType() == AbstractInsnNode.VAR_INSN) { 340 shiftVariable((VarInsnNode) insn); 341 } 342 } 343 } 344 345 /** Identify the instructions that must be changed. 346 * <p>Identification is based on data flow analysis. We start by changing 347 * the local variables in the initial frame to match the parameters of 348 * the derivative method, and propagate these variables following the 349 * instructions path, updating stack cells and local variables as needed. 350 * Instructions that must be changed are the ones that consume changed 351 * variables or stack cells.</p> 352 * @return set containing all the instructions that must be changed 353 */ 354 private Set<AbstractInsnNode> identifyChanges() { 355 356 // the pending set contains the values (local variables or stack cells) 357 // that have been changed, they will trigger changes on the instructions 358 // that consume them 359 final Set<TrackingValue> pending = new HashSet<TrackingValue>(); 360 361 // the changes set contains the instructions that must be changed 362 final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>(); 363 364 // start by converting the parameter of the method, 365 // which is kept in local variable 1 of the initial frame 366 final TrackingValue dpParameter = (TrackingValue) frames.get(instructions.get(0)).getLocal(1); 367 pending.add(dpParameter); 368 369 // propagate the values conversions throughout the method 370 while (!pending.isEmpty()) { 371 372 // pop one element from the set of changed values 373 final Iterator<TrackingValue> iterator = pending.iterator(); 374 final TrackingValue value = iterator.next(); 375 iterator.remove(); 376 377 // this value is converted 378 converted.add(value); 379 380 // check the consumers instructions for this value 381 for (final AbstractInsnNode consumer : value.getConsumers()) { 382 383 // an instruction consuming a converted value and producing 384 // a double must be changed to produce a differential pair, 385 // get the double values produced and add them to the changed set 386 for (TrackingValue produced : getProducedDoubleValues(consumer)) { 387 388 // add it to the pending set if it has not already been processed 389 if (!converted.contains(produced)) { 390 pending.add(produced); 391 } 392 393 } 394 395 // as a consumer of a converted value, the instruction must be changed 396 changes.add(consumer); 397 398 } 399 400 // check the producers instructions for this value 401 for (final AbstractInsnNode producer : value.getProducers()) { 402 403 // an instruction producing a converted value must be changed 404 changes.add(producer); 405 406 } 407 } 408 409 return changes; 410 411 } 412 413 /** Get the list of double values produced by an instruction. 414 * @param instruction instruction producing the values 415 * @return list of double values produced 416 */ 417 private List<TrackingValue> getProducedDoubleValues(final AbstractInsnNode instruction) { 418 419 final List<TrackingValue> values = new ArrayList<TrackingValue>(); 420 421 // get the frame before instruction execution 422 final Frame before = frames.get(instruction); 423 final int beforeStackSize = before.getStackSize(); 424 final int locals = before.getLocals(); 425 426 // check the frames produced by this instruction 427 // (they correspond to the input frames of its successors) 428 final Set<AbstractInsnNode> set = successors.get(instruction); 429 if (set != null) { 430 431 // loop over the successors of this instruction 432 for (final AbstractInsnNode successor : set) { 433 final Frame produced = frames.get(successor); 434 435 // check the stack cells 436 for (int i = 0; i < produced.getStackSize(); ++i) { 437 final TrackingValue value = (TrackingValue) produced.getStack(i); 438 if (((i >= beforeStackSize) || (value != before.getStack(i))) && 439 value.getValue().equals(BasicValue.DOUBLE_VALUE)) { 440 values.add(value); 441 } 442 } 443 444 // check the local variables 445 for (int i = 0; i < locals; ++i) { 446 final TrackingValue value = (TrackingValue) produced.getLocal(i); 447 if ((value != before.getLocal(i)) && 448 value.getValue().equals(BasicValue.DOUBLE_VALUE)) { 449 values.add(value); 450 } 451 } 452 } 453 } 454 455 return values; 456 457 } 458 459 /** Perform the code changes. 460 * @param changes instructions that must be changed 461 * @exception DifferentiationException if some instruction cannot be handled 462 */ 463 private void changeCode(final Set<AbstractInsnNode> changes) 464 throws DifferentiationException { 465 466 // insert the parameter conversion code at method start 467 final InsnList list = new InsnList(); 468 list.add(new VarInsnNode(Opcodes.ALOAD, 1)); 469 list.add(new InsnNode(Opcodes.DUP)); 470 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME, 471 "getValue", VOID_RETURN_D_DESCRIPTOR)); 472 list.add(new VarInsnNode(Opcodes.DSTORE, 1)); 473 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME, 474 "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR)); 475 list.add(new VarInsnNode(Opcodes.DSTORE, 3)); 476 477 instructions.insertBefore(instructions.get(0), list); 478 479 // transform the existing instructions 480 for (final AbstractInsnNode insn : changes) { 481 instructions.insert(insn, getReplacement(insn)); 482 instructions.remove(insn); 483 } 484 485 } 486 487 /** Get the replacement list for an instruction. 488 * @param insn instruction to replace 489 * @return replacement instructions list 490 * @exception DifferentiationException if some instruction cannot be handled 491 * or if no temporary variable can be reserved 492 */ 493 private InsnList getReplacement(final AbstractInsnNode insn) 494 throws DifferentiationException { 495 496 // get the frame at the start of the instruction 497 final Frame frame = frames.get(insn); 498 final int size = frame.getStackSize(); 499 final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2)); 500 final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1)); 501 502 switch(insn.getOpcode()) { 503 case Opcodes.DLOAD : 504 useLocal(((VarInsnNode) insn).var, 4); 505 return DLoadTransformer.getInstance().getReplacement(insn, this); 506 case Opcodes.DALOAD : 507 // TODO add support for DALOAD differentiation 508 throw new RuntimeException("DALOAD not handled yet"); 509 case Opcodes.DSTORE : 510 useLocal(((VarInsnNode) insn).var, 4); 511 return DStoreTransformer.getInstance().getReplacement(insn, this); 512 case Opcodes.DASTORE : 513 // TODO add support for DASTORE differentiation 514 throw new RuntimeException("DASTORE not handled yet"); 515 case Opcodes.DUP2 : 516 return Dup2Transformer.getInstance().getReplacement(insn, this); 517 case Opcodes.DUP2_X1 : 518 return Dup2X1Transformer.getInstance().getReplacement(insn, this); 519 case Opcodes.DUP2_X2 : 520 if (stack1Converted) { 521 if (stack0Converted) { 522 return Dup2X2Transformer12.getInstance().getReplacement(insn, this); 523 } 524 return Dup2X2Transformer1.getInstance().getReplacement(insn, this); 525 } 526 return Dup2X2Transformer2.getInstance().getReplacement(insn, this); 527 case Opcodes.DADD : 528 if (stack1Converted) { 529 if (stack0Converted) { 530 return DAddTransformer12.getInstance().getReplacement(insn, this); 531 } 532 return DAddTransformer1.getInstance().getReplacement(insn, this); 533 } 534 return DAddTransformer2.getInstance().getReplacement(insn, this); 535 case Opcodes.DSUB : 536 if (stack1Converted) { 537 if (stack0Converted) { 538 return DSubTransformer12.getInstance().getReplacement(insn, this); 539 } 540 return DSubTransformer1.getInstance().getReplacement(insn, this); 541 } 542 return DSubTransformer2.getInstance().getReplacement(insn, this); 543 case Opcodes.DMUL : 544 if (stack1Converted) { 545 if (stack0Converted) { 546 return DMulTransformer12.getInstance().getReplacement(insn, this); 547 } 548 return DMulTransformer1.getInstance().getReplacement(insn, this); 549 } 550 return DMulTransformer2.getInstance().getReplacement(insn, this); 551 case Opcodes.DDIV : 552 if (stack1Converted) { 553 if (stack0Converted) { 554 return DDivTransformer12.getInstance().getReplacement(insn, this); 555 } 556 return DDivTransformer1.getInstance().getReplacement(insn, this); 557 } 558 return DDivTransformer2.getInstance().getReplacement(insn, this); 559 case Opcodes.DREM : 560 if (stack1Converted) { 561 if (stack0Converted) { 562 return DRemTransformer12.getInstance().getReplacement(insn, this); 563 } 564 return DRemTransformer1.getInstance().getReplacement(insn, this); 565 } 566 return DRemTransformer2.getInstance().getReplacement(insn, this); 567 case Opcodes.DNEG : 568 return DNegTransformer.getInstance().getReplacement(insn, this); 569 case Opcodes.DCONST_0 : 570 case Opcodes.DCONST_1 : 571 case Opcodes.LDC : 572 case Opcodes.I2D : 573 case Opcodes.L2D : 574 case Opcodes.F2D : 575 return WideningTransformer.getInstance().getReplacement(insn, this); 576 case Opcodes.POP2 : 577 case Opcodes.D2I : 578 case Opcodes.D2L : 579 case Opcodes.D2F : 580 return NarrowingTransformer.getInstance().getReplacement(insn, this); 581 case Opcodes.DCMPL : 582 case Opcodes.DCMPG : 583 if (stack1Converted) { 584 if (stack0Converted) { 585 return DcmpTransformer12.getInstance().getReplacement(insn, this); 586 } 587 return DcmpTransformer1.getInstance().getReplacement(insn, this); 588 } 589 return DcmpTransformer2.getInstance().getReplacement(insn, this); 590 case Opcodes.DRETURN : 591 return DReturnTransformer.getInstance().getReplacement(insn, this); 592 case Opcodes.GETSTATIC : 593 // TODO add support for GETSTATIC differentiation 594 throw new RuntimeException("GETSTATIC not handled yet"); 595 case Opcodes.PUTSTATIC : 596 // TODO add support for PUTSTATIC differentiation 597 throw new RuntimeException("PUTSTATIC not handled yet"); 598 case Opcodes.GETFIELD : 599 // TODO add support for GETFIELD differentiation 600 throw new RuntimeException("GETFIELD not handled yet"); 601 case Opcodes.PUTFIELD : 602 // TODO add support for PUTFIELD differentiation 603 throw new RuntimeException("PUTFIELD not handled yet"); 604 case Opcodes.INVOKEVIRTUAL : 605 // TODO add support for INVOKEVIRTUAL differentiation 606 throw new RuntimeException("INVOKEVIRTUAL not handled yet"); 607 case Opcodes.INVOKESPECIAL : 608 // TODO add support for INVOKESPECIAL differentiation 609 throw new RuntimeException("INVOKESPECIAL not handled yet"); 610 case Opcodes.INVOKESTATIC : 611 return replaceInvocation((MethodInsnNode) insn, 612 stack1Converted, stack0Converted); 613 case Opcodes.INVOKEINTERFACE : 614 // TODO add support for INVOKEINTERFACE differentiation 615 throw new RuntimeException("INVOKEINTERFACE not handled yet"); 616 case Opcodes.NEWARRAY : 617 // TODO add support for NEWARRAY differentiation 618 throw new RuntimeException("NEWARRAY not handled yet"); 619 case Opcodes.ANEWARRAY : 620 // TODO add support for ANEWARRAY differentiation 621 throw new RuntimeException("ANEWARRAY not handled yet"); 622 case Opcodes.MULTIANEWARRAY : 623 // TODO add support for MULTIANEWARRAY differentiation 624 throw new RuntimeException("MULTIANEWARRAY not handled yet"); 625 default: 626 throw new DifferentiationException("unable to handle instruction with opcode {0}", 627 new Object[] { 628 Integer.valueOf(insn.getOpcode()) 629 }); 630 } 631 632 } 633 634 /** Replace an INVOKESTATIC instruction. 635 * @param methodInsn invocation instruction 636 * @param stack1Converted if true, the stack sub-head has been 637 * converted to differential pair 638 * @param stack0Converted if true, the stack head has been 639 * converted to differential pair 640 * @return replacement instructions list 641 * @exception DifferentiationException if the instruction cannot be replaced 642 */ 643 private InsnList replaceInvocation(final MethodInsnNode methodInsn, 644 final boolean stack1Converted, 645 final boolean stack0Converted) 646 throws DifferentiationException { 647 if (isMathImplementationClass(methodInsn.owner)) { 648 if ("(D)D".equals(methodInsn.desc)) { 649 // this is a univariate method like sin, cos, exp ... 650 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name); 651 if (transformer == null) { 652 throw new DifferentiationException(UNKNOWN_METHOD_FMT, 653 methodInsn.owner, methodInsn.name); 654 } 655 return transformer.getReplacementList(methodInsn.owner, this); 656 } else if ("(DD)D".equals(methodInsn.desc)) { 657 // this is a bivariate method like atan2, pow ... 658 659 // we may want to differentiate against first, second or both parameters 660 String name = null; 661 if (stack1Converted) { 662 if (stack0Converted) { 663 name = methodInsn.name + "_12"; 664 } else { 665 name = methodInsn.name + "_1"; 666 } 667 } else if (stack0Converted) { 668 name = methodInsn.name + "_2"; 669 } 670 671 if (name != null) { 672 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name); 673 if (transformer == null) { 674 throw new DifferentiationException(UNKNOWN_METHOD_FMT, 675 methodInsn.owner, methodInsn.name); 676 } 677 return transformer.getReplacementList(methodInsn.owner, this); 678 } 679 } 680 } 681 throw new DifferentiationException("unexpected instruction {0}", 682 Integer.valueOf(methodInsn.getOpcode())); 683 } 684 685 /** Test if a class is a math implementation class. 686 * @param name name of the class to test 687 * @return true if the named class is a math implementation class 688 */ 689 public boolean isMathImplementationClass(final String name) { 690 return mathClasses.contains(name); 691 } 692 693 /** Set a local variable as used by the modified code. 694 * @param index index of the variable 695 * @param size size of the variable (1 or 2 for standard variables, 696 * 4 for special expanded differential pairs) 697 * @exception DifferentiationException if the number of the 698 * temporary variable lies outside of the allowed range 699 */ 700 public void useLocal(final int index, final int size) 701 throws DifferentiationException { 702 if ((index < 0) || ((index + size - 1) >= usedLocals.length)) { 703 throw new DifferentiationException("index of size {0} local variable ({1}) " + 704 "outside of [{2}, {3}] range", 705 Integer.valueOf(size), Integer.valueOf(index), 706 Integer.valueOf(1), Integer.valueOf(MAX_TEMP)); 707 } 708 for (int i = index; i < index + size; ++i) { 709 usedLocals[i] = true; 710 } 711 } 712 713 /** Get index of a size 2 temporary variable. 714 * <p>Temporary variables can be used for very short duration 715 * storage of size 2 values (i.e one double, or one long or two 716 * integers). These variables are reused in many replacement 717 * instructions sequences, so their content may be overridden 718 * at any time: they should be considered to have a scope 719 * limited to one replacement sequence only. This means that 720 * one should <em>not</em> store a value in a variable in one 721 * replacement sequence and retrieve it later in another 722 * replacement sequence as it may have been overridden in 723 * between.</p> 724 * <p>At most 5 temporary variables may be used.</p> 725 * @param number number of the temporary variable (must be 726 * between 1 and the maximal number of available temporary 727 * variables) 728 * @return index of the variable within the local variables 729 * array 730 * @exception DifferentiationException if the number of the 731 * temporary variable lies outside of the allowed range 732 */ 733 public int getTmp(final int number) throws DifferentiationException { 734 if ((number < 0) || (number > MAX_TEMP)) { 735 throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range", 736 Integer.valueOf(number), 737 Integer.valueOf(1), 738 Integer.valueOf(MAX_TEMP)); 739 } 740 final int index = usedLocals.length - 2 * number; 741 useLocal(index, 2); 742 return index; 743 } 744 745 /** Shifted the index of a variable instruction. 746 * @param insn variable instruction 747 */ 748 public void shiftVariable(final VarInsnNode insn) { 749 int shifted = 0; 750 for (int i = 0; i < insn.var; ++i) { 751 if (usedLocals[i]) { 752 ++shifted; 753 } 754 } 755 insn.var = shifted; 756 } 757 758 /** Clone an instruction. 759 * @param insn instruction to clone 760 * @return cloned instruction 761 */ 762 public AbstractInsnNode clone(final AbstractInsnNode insn) { 763 return insn.clone(clonedLabels); 764 } 765 766 /** Analyzer preserving instructions successors information. */ 767 private class FlowAnalyzer extends Analyzer { 768 769 /** Simple constructor. 770 * @param interpreter associated interpreter 771 */ 772 public FlowAnalyzer(final Interpreter interpreter) { 773 super(interpreter); 774 } 775 776 /** Store a new edge. 777 * @param insn current instruction 778 * @param successor successor instruction 779 */ 780 protected void newControlFlowEdge(final int insn, final int successor) { 781 // store the successor information 782 final AbstractInsnNode node = instructions.get(insn); 783 Set<AbstractInsnNode> set = successors.get(node); 784 if (set == null) { 785 set = new HashSet<AbstractInsnNode>(); 786 successors.put(node, set); 787 } 788 set.add(instructions.get(successor)); 789 } 790 791 } 792 793 }