001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.nabla.algorithmic.forward.analysis;
018
019 import java.util.ArrayList;
020 import java.util.HashMap;
021 import java.util.HashSet;
022 import java.util.IdentityHashMap;
023 import java.util.Iterator;
024 import java.util.List;
025 import java.util.Map;
026 import java.util.Set;
027
028 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer1;
029 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer12;
030 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer2;
031 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer1;
032 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer12;
033 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DDivTransformer2;
034 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer1;
035 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer12;
036 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DMulTransformer2;
037 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DNegTransformer;
038 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer1;
039 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer12;
040 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DRemTransformer2;
041 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer1;
042 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer12;
043 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DSubTransformer2;
044 import org.apache.commons.nabla.algorithmic.forward.functions.AcosTransformer;
045 import org.apache.commons.nabla.algorithmic.forward.functions.AcoshTransformer;
046 import org.apache.commons.nabla.algorithmic.forward.functions.AsinTransformer;
047 import org.apache.commons.nabla.algorithmic.forward.functions.AsinhTransformer;
048 import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer1;
049 import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer12;
050 import org.apache.commons.nabla.algorithmic.forward.functions.Atan2Transformer2;
051 import org.apache.commons.nabla.algorithmic.forward.functions.AtanTransformer;
052 import org.apache.commons.nabla.algorithmic.forward.functions.AtanhTransformer;
053 import org.apache.commons.nabla.algorithmic.forward.functions.CbrtTransformer;
054 import org.apache.commons.nabla.algorithmic.forward.functions.CosTransformer;
055 import org.apache.commons.nabla.algorithmic.forward.functions.CoshTransformer;
056 import org.apache.commons.nabla.algorithmic.forward.functions.ExpTransformer;
057 import org.apache.commons.nabla.algorithmic.forward.functions.Expm1Transformer;
058 import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer1;
059 import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer12;
060 import org.apache.commons.nabla.algorithmic.forward.functions.HypotTransformer2;
061 import org.apache.commons.nabla.algorithmic.forward.functions.Log10Transformer;
062 import org.apache.commons.nabla.algorithmic.forward.functions.Log1pTransformer;
063 import org.apache.commons.nabla.algorithmic.forward.functions.LogTransformer;
064 import org.apache.commons.nabla.algorithmic.forward.functions.MathInvocationTransformer;
065 import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer1;
066 import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer12;
067 import org.apache.commons.nabla.algorithmic.forward.functions.PowTransformer2;
068 import org.apache.commons.nabla.algorithmic.forward.functions.SinTransformer;
069 import org.apache.commons.nabla.algorithmic.forward.functions.SinhTransformer;
070 import org.apache.commons.nabla.algorithmic.forward.functions.SqrtTransformer;
071 import org.apache.commons.nabla.algorithmic.forward.functions.TanTransformer;
072 import org.apache.commons.nabla.algorithmic.forward.functions.TanhTransformer;
073 import org.apache.commons.nabla.algorithmic.forward.instructions.DLoadTransformer;
074 import org.apache.commons.nabla.algorithmic.forward.instructions.DReturnTransformer;
075 import org.apache.commons.nabla.algorithmic.forward.instructions.DStoreTransformer;
076 import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer1;
077 import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer12;
078 import org.apache.commons.nabla.algorithmic.forward.instructions.DcmpTransformer2;
079 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2Transformer;
080 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X1Transformer;
081 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer1;
082 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer12;
083 import org.apache.commons.nabla.algorithmic.forward.instructions.Dup2X2Transformer2;
084 import org.apache.commons.nabla.algorithmic.forward.instructions.NarrowingTransformer;
085 import org.apache.commons.nabla.algorithmic.forward.instructions.WideningTransformer;
086 import org.apache.commons.nabla.algorithmic.forward.trimming.DLoadPop2Trimmer;
087 import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDloadTrimmer;
088 import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDstoreTrimmer;
089 import org.apache.commons.nabla.core.DifferentialPair;
090 import org.apache.commons.nabla.core.DifferentiationException;
091 import org.objectweb.asm.MethodVisitor;
092 import org.objectweb.asm.Opcodes;
093 import org.objectweb.asm.tree.AbstractInsnNode;
094 import org.objectweb.asm.tree.FieldInsnNode;
095 import org.objectweb.asm.tree.IincInsnNode;
096 import org.objectweb.asm.tree.InsnList;
097 import org.objectweb.asm.tree.InsnNode;
098 import org.objectweb.asm.tree.LabelNode;
099 import org.objectweb.asm.tree.MethodInsnNode;
100 import org.objectweb.asm.tree.MethodNode;
101 import org.objectweb.asm.tree.VarInsnNode;
102 import org.objectweb.asm.tree.analysis.Analyzer;
103 import org.objectweb.asm.tree.analysis.AnalyzerException;
104 import org.objectweb.asm.tree.analysis.BasicValue;
105 import org.objectweb.asm.tree.analysis.Frame;
106 import org.objectweb.asm.tree.analysis.Interpreter;
107
108 /** Class transforming a method computing a value to a method
109 * computing both a value and its differential.
110 */
111 public class MethodDifferentiator extends MethodNode {
112
113 /** Name for the DifferentialPair class. */
114 public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
115
116 /** Descriptor for the DifferentialPair class. */
117 public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
118
119 /** Descriptor for the derivative class f method. */
120 public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
121
122 /** Descriptor for <code>double f()</code> methods. */
123 private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
124
125 /** Math functions transformer. */
126 private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS =
127 new HashMap<String, MathInvocationTransformer>();
128
129 static {
130 MATH_TRANSFORMERS.put("acos", new AcosTransformer());
131 MATH_TRANSFORMERS.put("acosh", new AcoshTransformer());
132 MATH_TRANSFORMERS.put("asin", new AsinTransformer());
133 MATH_TRANSFORMERS.put("asinh", new AsinhTransformer());
134 MATH_TRANSFORMERS.put("atan2_12", new Atan2Transformer12());
135 MATH_TRANSFORMERS.put("atan2_1", new Atan2Transformer1());
136 MATH_TRANSFORMERS.put("atan2_2", new Atan2Transformer2());
137 MATH_TRANSFORMERS.put("atan", new AtanTransformer());
138 MATH_TRANSFORMERS.put("atanh", new AtanhTransformer());
139 MATH_TRANSFORMERS.put("cbrt", new CbrtTransformer());
140 MATH_TRANSFORMERS.put("cos", new CosTransformer());
141 MATH_TRANSFORMERS.put("cosh", new CoshTransformer());
142 MATH_TRANSFORMERS.put("exp", new ExpTransformer());
143 MATH_TRANSFORMERS.put("expm1", new Expm1Transformer());
144 MATH_TRANSFORMERS.put("hypot_12", new HypotTransformer12());
145 MATH_TRANSFORMERS.put("hypot_1", new HypotTransformer1());
146 MATH_TRANSFORMERS.put("hypot_2", new HypotTransformer2());
147 MATH_TRANSFORMERS.put("log10", new Log10Transformer());
148 MATH_TRANSFORMERS.put("log1p", new Log1pTransformer());
149 MATH_TRANSFORMERS.put("log", new LogTransformer());
150 MATH_TRANSFORMERS.put("pow_12", new PowTransformer12());
151 MATH_TRANSFORMERS.put("pow_1", new PowTransformer1());
152 MATH_TRANSFORMERS.put("pow_2", new PowTransformer2());
153 MATH_TRANSFORMERS.put("sin", new SinTransformer());
154 MATH_TRANSFORMERS.put("sinh", new SinhTransformer());
155 MATH_TRANSFORMERS.put("sqrt", new SqrtTransformer());
156 MATH_TRANSFORMERS.put("tan", new TanTransformer());
157 MATH_TRANSFORMERS.put("tanh", new TanhTransformer());
158 }
159
160 /** Message format for unknown method. */
161 private static final String UNKNOWN_METHOD_FMT = "unknown method {0}.{1}";
162
163 /** Maximal number of temporary size 2 variables. */
164 private static final int MAX_TEMP = 5;
165
166 /** Math implementation classes. */
167 private final Set<String> mathClasses;
168
169 /** Generator to use. */
170 private final MethodVisitor generator;
171
172 /** Used locals variables array. */
173 private boolean[] usedLocals;
174
175 /** Primitive class name. */
176 private final String primitiveName;
177
178 /** Error reporter to use. */
179 private final ErrorReporter errorReporter;
180
181 /** Set of converted values. */
182 private final Set<TrackingValue> converted;
183
184 /** Frames for the original method. */
185 private final Map<AbstractInsnNode, Frame> frames;
186
187 /** Instructions successors array. */
188 private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors;
189
190 /** Cloned labels map. */
191 private final Map<LabelNode, LabelNode> clonedLabels;
192
193 /** Build a differentiator for a method.
194 * @param access access flags of the method
195 * @param name name of the method
196 * @param desc descriptor of the method
197 * @param signature signature of the method
198 * @param exceptions exceptions thrown by the method
199 * @param generator bytecode generator to use for the transformed method
200 * @param primitiveName primitive class name
201 * @param mathClasses math implementation classes
202 * @param errorReporter reporter used for delaying exceptions
203 */
204 public MethodDifferentiator(final int access, final String name, final String desc,
205 final String signature, final String[] exceptions,
206 final MethodVisitor generator,final String primitiveName,
207 final Set<String> mathClasses,
208 final ErrorReporter errorReporter) {
209
210 super(access, name, desc, signature, exceptions);
211 this.generator = generator;
212 this.usedLocals = null;
213 this.primitiveName = primitiveName;
214 this.mathClasses = mathClasses;
215 this.errorReporter = errorReporter;
216 this.converted = new HashSet<TrackingValue>();
217 this.frames = new IdentityHashMap<AbstractInsnNode, Frame>();
218 this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
219 this.clonedLabels = new HashMap<LabelNode, LabelNode>();
220
221 }
222
223 /** {@inheritDoc} */
224 @Override
225 public void visitEnd() {
226 try {
227
228 // at start, "this" and one differential pair are used
229 maxLocals = 2 * (maxLocals + MAX_TEMP) - 1;
230 usedLocals = new boolean[maxLocals];
231 useLocal(0, 1);
232 useLocal(1, 4);
233
234 // add spare cells to hold new variables if needed
235 addSpareLocalVariables();
236
237 // analyze the original code, tracing values production/consumption
238 final Frame[] array =
239 new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this);
240
241 // convert the array into a map, since code changes will shift all indices
242 for (int i = 0; i < array.length; ++i) {
243 frames.put(instructions.get(i), array[i]);
244 }
245
246 // identify the needed changes
247 final Set<AbstractInsnNode> changes = identifyChanges();
248
249 if (changes.isEmpty()) {
250
251 // the method does not depend on the parameter at all!
252 // we replace all code by a simple "return DifferentialPair.ZERO;"
253 instructions.clear();
254 instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR));
255 instructions.add(new InsnNode(Opcodes.ARETURN));
256
257 } else {
258
259 // perform the code changes
260 changeCode(changes);
261
262 // remove the local variables added at the beginning and not used
263 removeUnusedSpareLocalVariables();
264
265 // trim generated instructions list
266 SwappedDloadTrimmer.getInstance().trim(instructions);
267 SwappedDstoreTrimmer.getInstance().trim(instructions);
268 DLoadPop2Trimmer.getInstance().trim(instructions);
269
270 }
271
272 // change the descriptor to its true final value
273 desc = DP_RETURN_DP_DESCRIPTOR;
274
275 // generate the method
276 accept(generator);
277
278 } catch (AnalyzerException ae) {
279 if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
280 errorReporter.register((DifferentiationException) ae.getCause());
281 } else {
282 final DifferentiationException de =
283 new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
284 new Object[] {
285 primitiveName, name, ae.getMessage()
286 });
287 errorReporter.register(de);
288 }
289 } catch (DifferentiationException de) {
290 errorReporter.register(de);
291 }
292 }
293
294 /** Add spare cells for new local variables.
295 * <p>In order to ease conversion from double values to differential pairs,
296 * we start by reserving one spare cell between each original local variables.
297 * So we have to modify the indices in all instructions referencing local
298 * variables in the original code, to take into account the renumbering
299 * introduced by these spare cells. The spare cells by themselves will
300 * be referenced by the converted instructions in the following passes.</p>
301 * <p>The spare cells that will not be used will be reclaimed after
302 * conversion, to avoid wasting memory.</p>
303 * @exception DifferentiationException if local variables array has not been
304 * expanded appropriately beforehand
305 * @see #removeUnusedSpareLocalVariables()
306 */
307 private void addSpareLocalVariables() throws DifferentiationException {
308 for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
309 final AbstractInsnNode insn = (AbstractInsnNode) i.next();
310 if (insn.getType() == AbstractInsnNode.VAR_INSN) {
311 final VarInsnNode varInsn = (VarInsnNode) insn;
312 if (varInsn.var > 2) {
313 varInsn.var = 2 * varInsn.var - 1;
314 final int opcode = varInsn.getOpcode();
315 if ((opcode == Opcodes.ILOAD) || (opcode == Opcodes.FLOAD) ||
316 (opcode == Opcodes.ALOAD) || (opcode == Opcodes.ISTORE) ||
317 (opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) {
318 useLocal(varInsn.var, 1);
319 } else {
320 useLocal(varInsn.var, 2);
321 }
322 }
323 } else if (insn.getOpcode() == Opcodes.IINC) {
324 final IincInsnNode iincInsn = (IincInsnNode) insn;
325 if (iincInsn.var > 2) {
326 iincInsn.var = 2 * iincInsn.var - 1;
327 useLocal(iincInsn.var, 1);
328 }
329 }
330 }
331 }
332
333 /** Remove the unused spare cells introduced at conversion start.
334 * @see #addSpareLocalVariables()
335 */
336 private void removeUnusedSpareLocalVariables() {
337 for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
338 final AbstractInsnNode insn = (AbstractInsnNode) i.next();
339 if (insn.getType() == AbstractInsnNode.VAR_INSN) {
340 shiftVariable((VarInsnNode) insn);
341 }
342 }
343 }
344
345 /** Identify the instructions that must be changed.
346 * <p>Identification is based on data flow analysis. We start by changing
347 * the local variables in the initial frame to match the parameters of
348 * the derivative method, and propagate these variables following the
349 * instructions path, updating stack cells and local variables as needed.
350 * Instructions that must be changed are the ones that consume changed
351 * variables or stack cells.</p>
352 * @return set containing all the instructions that must be changed
353 */
354 private Set<AbstractInsnNode> identifyChanges() {
355
356 // the pending set contains the values (local variables or stack cells)
357 // that have been changed, they will trigger changes on the instructions
358 // that consume them
359 final Set<TrackingValue> pending = new HashSet<TrackingValue>();
360
361 // the changes set contains the instructions that must be changed
362 final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
363
364 // start by converting the parameter of the method,
365 // which is kept in local variable 1 of the initial frame
366 final TrackingValue dpParameter = (TrackingValue) frames.get(instructions.get(0)).getLocal(1);
367 pending.add(dpParameter);
368
369 // propagate the values conversions throughout the method
370 while (!pending.isEmpty()) {
371
372 // pop one element from the set of changed values
373 final Iterator<TrackingValue> iterator = pending.iterator();
374 final TrackingValue value = iterator.next();
375 iterator.remove();
376
377 // this value is converted
378 converted.add(value);
379
380 // check the consumers instructions for this value
381 for (final AbstractInsnNode consumer : value.getConsumers()) {
382
383 // an instruction consuming a converted value and producing
384 // a double must be changed to produce a differential pair,
385 // get the double values produced and add them to the changed set
386 for (TrackingValue produced : getProducedDoubleValues(consumer)) {
387
388 // add it to the pending set if it has not already been processed
389 if (!converted.contains(produced)) {
390 pending.add(produced);
391 }
392
393 }
394
395 // as a consumer of a converted value, the instruction must be changed
396 changes.add(consumer);
397
398 }
399
400 // check the producers instructions for this value
401 for (final AbstractInsnNode producer : value.getProducers()) {
402
403 // an instruction producing a converted value must be changed
404 changes.add(producer);
405
406 }
407 }
408
409 return changes;
410
411 }
412
413 /** Get the list of double values produced by an instruction.
414 * @param instruction instruction producing the values
415 * @return list of double values produced
416 */
417 private List<TrackingValue> getProducedDoubleValues(final AbstractInsnNode instruction) {
418
419 final List<TrackingValue> values = new ArrayList<TrackingValue>();
420
421 // get the frame before instruction execution
422 final Frame before = frames.get(instruction);
423 final int beforeStackSize = before.getStackSize();
424 final int locals = before.getLocals();
425
426 // check the frames produced by this instruction
427 // (they correspond to the input frames of its successors)
428 final Set<AbstractInsnNode> set = successors.get(instruction);
429 if (set != null) {
430
431 // loop over the successors of this instruction
432 for (final AbstractInsnNode successor : set) {
433 final Frame produced = frames.get(successor);
434
435 // check the stack cells
436 for (int i = 0; i < produced.getStackSize(); ++i) {
437 final TrackingValue value = (TrackingValue) produced.getStack(i);
438 if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
439 value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
440 values.add(value);
441 }
442 }
443
444 // check the local variables
445 for (int i = 0; i < locals; ++i) {
446 final TrackingValue value = (TrackingValue) produced.getLocal(i);
447 if ((value != before.getLocal(i)) &&
448 value.getValue().equals(BasicValue.DOUBLE_VALUE)) {
449 values.add(value);
450 }
451 }
452 }
453 }
454
455 return values;
456
457 }
458
459 /** Perform the code changes.
460 * @param changes instructions that must be changed
461 * @exception DifferentiationException if some instruction cannot be handled
462 */
463 private void changeCode(final Set<AbstractInsnNode> changes)
464 throws DifferentiationException {
465
466 // insert the parameter conversion code at method start
467 final InsnList list = new InsnList();
468 list.add(new VarInsnNode(Opcodes.ALOAD, 1));
469 list.add(new InsnNode(Opcodes.DUP));
470 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
471 "getValue", VOID_RETURN_D_DESCRIPTOR));
472 list.add(new VarInsnNode(Opcodes.DSTORE, 1));
473 list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
474 "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR));
475 list.add(new VarInsnNode(Opcodes.DSTORE, 3));
476
477 instructions.insertBefore(instructions.get(0), list);
478
479 // transform the existing instructions
480 for (final AbstractInsnNode insn : changes) {
481 instructions.insert(insn, getReplacement(insn));
482 instructions.remove(insn);
483 }
484
485 }
486
487 /** Get the replacement list for an instruction.
488 * @param insn instruction to replace
489 * @return replacement instructions list
490 * @exception DifferentiationException if some instruction cannot be handled
491 * or if no temporary variable can be reserved
492 */
493 private InsnList getReplacement(final AbstractInsnNode insn)
494 throws DifferentiationException {
495
496 // get the frame at the start of the instruction
497 final Frame frame = frames.get(insn);
498 final int size = frame.getStackSize();
499 final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2));
500 final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1));
501
502 switch(insn.getOpcode()) {
503 case Opcodes.DLOAD :
504 useLocal(((VarInsnNode) insn).var, 4);
505 return DLoadTransformer.getInstance().getReplacement(insn, this);
506 case Opcodes.DALOAD :
507 // TODO add support for DALOAD differentiation
508 throw new RuntimeException("DALOAD not handled yet");
509 case Opcodes.DSTORE :
510 useLocal(((VarInsnNode) insn).var, 4);
511 return DStoreTransformer.getInstance().getReplacement(insn, this);
512 case Opcodes.DASTORE :
513 // TODO add support for DASTORE differentiation
514 throw new RuntimeException("DASTORE not handled yet");
515 case Opcodes.DUP2 :
516 return Dup2Transformer.getInstance().getReplacement(insn, this);
517 case Opcodes.DUP2_X1 :
518 return Dup2X1Transformer.getInstance().getReplacement(insn, this);
519 case Opcodes.DUP2_X2 :
520 if (stack1Converted) {
521 if (stack0Converted) {
522 return Dup2X2Transformer12.getInstance().getReplacement(insn, this);
523 }
524 return Dup2X2Transformer1.getInstance().getReplacement(insn, this);
525 }
526 return Dup2X2Transformer2.getInstance().getReplacement(insn, this);
527 case Opcodes.DADD :
528 if (stack1Converted) {
529 if (stack0Converted) {
530 return DAddTransformer12.getInstance().getReplacement(insn, this);
531 }
532 return DAddTransformer1.getInstance().getReplacement(insn, this);
533 }
534 return DAddTransformer2.getInstance().getReplacement(insn, this);
535 case Opcodes.DSUB :
536 if (stack1Converted) {
537 if (stack0Converted) {
538 return DSubTransformer12.getInstance().getReplacement(insn, this);
539 }
540 return DSubTransformer1.getInstance().getReplacement(insn, this);
541 }
542 return DSubTransformer2.getInstance().getReplacement(insn, this);
543 case Opcodes.DMUL :
544 if (stack1Converted) {
545 if (stack0Converted) {
546 return DMulTransformer12.getInstance().getReplacement(insn, this);
547 }
548 return DMulTransformer1.getInstance().getReplacement(insn, this);
549 }
550 return DMulTransformer2.getInstance().getReplacement(insn, this);
551 case Opcodes.DDIV :
552 if (stack1Converted) {
553 if (stack0Converted) {
554 return DDivTransformer12.getInstance().getReplacement(insn, this);
555 }
556 return DDivTransformer1.getInstance().getReplacement(insn, this);
557 }
558 return DDivTransformer2.getInstance().getReplacement(insn, this);
559 case Opcodes.DREM :
560 if (stack1Converted) {
561 if (stack0Converted) {
562 return DRemTransformer12.getInstance().getReplacement(insn, this);
563 }
564 return DRemTransformer1.getInstance().getReplacement(insn, this);
565 }
566 return DRemTransformer2.getInstance().getReplacement(insn, this);
567 case Opcodes.DNEG :
568 return DNegTransformer.getInstance().getReplacement(insn, this);
569 case Opcodes.DCONST_0 :
570 case Opcodes.DCONST_1 :
571 case Opcodes.LDC :
572 case Opcodes.I2D :
573 case Opcodes.L2D :
574 case Opcodes.F2D :
575 return WideningTransformer.getInstance().getReplacement(insn, this);
576 case Opcodes.POP2 :
577 case Opcodes.D2I :
578 case Opcodes.D2L :
579 case Opcodes.D2F :
580 return NarrowingTransformer.getInstance().getReplacement(insn, this);
581 case Opcodes.DCMPL :
582 case Opcodes.DCMPG :
583 if (stack1Converted) {
584 if (stack0Converted) {
585 return DcmpTransformer12.getInstance().getReplacement(insn, this);
586 }
587 return DcmpTransformer1.getInstance().getReplacement(insn, this);
588 }
589 return DcmpTransformer2.getInstance().getReplacement(insn, this);
590 case Opcodes.DRETURN :
591 return DReturnTransformer.getInstance().getReplacement(insn, this);
592 case Opcodes.GETSTATIC :
593 // TODO add support for GETSTATIC differentiation
594 throw new RuntimeException("GETSTATIC not handled yet");
595 case Opcodes.PUTSTATIC :
596 // TODO add support for PUTSTATIC differentiation
597 throw new RuntimeException("PUTSTATIC not handled yet");
598 case Opcodes.GETFIELD :
599 // TODO add support for GETFIELD differentiation
600 throw new RuntimeException("GETFIELD not handled yet");
601 case Opcodes.PUTFIELD :
602 // TODO add support for PUTFIELD differentiation
603 throw new RuntimeException("PUTFIELD not handled yet");
604 case Opcodes.INVOKEVIRTUAL :
605 // TODO add support for INVOKEVIRTUAL differentiation
606 throw new RuntimeException("INVOKEVIRTUAL not handled yet");
607 case Opcodes.INVOKESPECIAL :
608 // TODO add support for INVOKESPECIAL differentiation
609 throw new RuntimeException("INVOKESPECIAL not handled yet");
610 case Opcodes.INVOKESTATIC :
611 return replaceInvocation((MethodInsnNode) insn,
612 stack1Converted, stack0Converted);
613 case Opcodes.INVOKEINTERFACE :
614 // TODO add support for INVOKEINTERFACE differentiation
615 throw new RuntimeException("INVOKEINTERFACE not handled yet");
616 case Opcodes.NEWARRAY :
617 // TODO add support for NEWARRAY differentiation
618 throw new RuntimeException("NEWARRAY not handled yet");
619 case Opcodes.ANEWARRAY :
620 // TODO add support for ANEWARRAY differentiation
621 throw new RuntimeException("ANEWARRAY not handled yet");
622 case Opcodes.MULTIANEWARRAY :
623 // TODO add support for MULTIANEWARRAY differentiation
624 throw new RuntimeException("MULTIANEWARRAY not handled yet");
625 default:
626 throw new DifferentiationException("unable to handle instruction with opcode {0}",
627 new Object[] {
628 Integer.valueOf(insn.getOpcode())
629 });
630 }
631
632 }
633
634 /** Replace an INVOKESTATIC instruction.
635 * @param methodInsn invocation instruction
636 * @param stack1Converted if true, the stack sub-head has been
637 * converted to differential pair
638 * @param stack0Converted if true, the stack head has been
639 * converted to differential pair
640 * @return replacement instructions list
641 * @exception DifferentiationException if the instruction cannot be replaced
642 */
643 private InsnList replaceInvocation(final MethodInsnNode methodInsn,
644 final boolean stack1Converted,
645 final boolean stack0Converted)
646 throws DifferentiationException {
647 if (isMathImplementationClass(methodInsn.owner)) {
648 if ("(D)D".equals(methodInsn.desc)) {
649 // this is a univariate method like sin, cos, exp ...
650 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name);
651 if (transformer == null) {
652 throw new DifferentiationException(UNKNOWN_METHOD_FMT,
653 methodInsn.owner, methodInsn.name);
654 }
655 return transformer.getReplacementList(methodInsn.owner, this);
656 } else if ("(DD)D".equals(methodInsn.desc)) {
657 // this is a bivariate method like atan2, pow ...
658
659 // we may want to differentiate against first, second or both parameters
660 String name = null;
661 if (stack1Converted) {
662 if (stack0Converted) {
663 name = methodInsn.name + "_12";
664 } else {
665 name = methodInsn.name + "_1";
666 }
667 } else if (stack0Converted) {
668 name = methodInsn.name + "_2";
669 }
670
671 if (name != null) {
672 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name);
673 if (transformer == null) {
674 throw new DifferentiationException(UNKNOWN_METHOD_FMT,
675 methodInsn.owner, methodInsn.name);
676 }
677 return transformer.getReplacementList(methodInsn.owner, this);
678 }
679 }
680 }
681 throw new DifferentiationException("unexpected instruction {0}",
682 Integer.valueOf(methodInsn.getOpcode()));
683 }
684
685 /** Test if a class is a math implementation class.
686 * @param name name of the class to test
687 * @return true if the named class is a math implementation class
688 */
689 public boolean isMathImplementationClass(final String name) {
690 return mathClasses.contains(name);
691 }
692
693 /** Set a local variable as used by the modified code.
694 * @param index index of the variable
695 * @param size size of the variable (1 or 2 for standard variables,
696 * 4 for special expanded differential pairs)
697 * @exception DifferentiationException if the number of the
698 * temporary variable lies outside of the allowed range
699 */
700 public void useLocal(final int index, final int size)
701 throws DifferentiationException {
702 if ((index < 0) || ((index + size - 1) >= usedLocals.length)) {
703 throw new DifferentiationException("index of size {0} local variable ({1}) " +
704 "outside of [{2}, {3}] range",
705 Integer.valueOf(size), Integer.valueOf(index),
706 Integer.valueOf(1), Integer.valueOf(MAX_TEMP));
707 }
708 for (int i = index; i < index + size; ++i) {
709 usedLocals[i] = true;
710 }
711 }
712
713 /** Get index of a size 2 temporary variable.
714 * <p>Temporary variables can be used for very short duration
715 * storage of size 2 values (i.e one double, or one long or two
716 * integers). These variables are reused in many replacement
717 * instructions sequences, so their content may be overridden
718 * at any time: they should be considered to have a scope
719 * limited to one replacement sequence only. This means that
720 * one should <em>not</em> store a value in a variable in one
721 * replacement sequence and retrieve it later in another
722 * replacement sequence as it may have been overridden in
723 * between.</p>
724 * <p>At most 5 temporary variables may be used.</p>
725 * @param number number of the temporary variable (must be
726 * between 1 and the maximal number of available temporary
727 * variables)
728 * @return index of the variable within the local variables
729 * array
730 * @exception DifferentiationException if the number of the
731 * temporary variable lies outside of the allowed range
732 */
733 public int getTmp(final int number) throws DifferentiationException {
734 if ((number < 0) || (number > MAX_TEMP)) {
735 throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range",
736 Integer.valueOf(number),
737 Integer.valueOf(1),
738 Integer.valueOf(MAX_TEMP));
739 }
740 final int index = usedLocals.length - 2 * number;
741 useLocal(index, 2);
742 return index;
743 }
744
745 /** Shifted the index of a variable instruction.
746 * @param insn variable instruction
747 */
748 public void shiftVariable(final VarInsnNode insn) {
749 int shifted = 0;
750 for (int i = 0; i < insn.var; ++i) {
751 if (usedLocals[i]) {
752 ++shifted;
753 }
754 }
755 insn.var = shifted;
756 }
757
758 /** Clone an instruction.
759 * @param insn instruction to clone
760 * @return cloned instruction
761 */
762 public AbstractInsnNode clone(final AbstractInsnNode insn) {
763 return insn.clone(clonedLabels);
764 }
765
766 /** Analyzer preserving instructions successors information. */
767 private class FlowAnalyzer extends Analyzer {
768
769 /** Simple constructor.
770 * @param interpreter associated interpreter
771 */
772 public FlowAnalyzer(final Interpreter interpreter) {
773 super(interpreter);
774 }
775
776 /** Store a new edge.
777 * @param insn current instruction
778 * @param successor successor instruction
779 */
780 protected void newControlFlowEdge(final int insn, final int successor) {
781 // store the successor information
782 final AbstractInsnNode node = instructions.get(insn);
783 Set<AbstractInsnNode> set = successors.get(node);
784 if (set == null) {
785 set = new HashSet<AbstractInsnNode>();
786 successors.put(node, set);
787 }
788 set.add(instructions.get(successor));
789 }
790
791 }
792
793 }