| 1 | |
|
| 2 | |
|
| 3 | |
|
| 4 | |
|
| 5 | |
|
| 6 | |
|
| 7 | |
|
| 8 | |
|
| 9 | |
|
| 10 | |
|
| 11 | |
|
| 12 | |
|
| 13 | |
|
| 14 | |
|
| 15 | |
|
| 16 | |
|
| 17 | |
package org.apache.commons.nabla.automatic.analysis; |
| 18 | |
|
| 19 | |
import java.util.ArrayList; |
| 20 | |
import java.util.HashMap; |
| 21 | |
import java.util.HashSet; |
| 22 | |
import java.util.IdentityHashMap; |
| 23 | |
import java.util.Iterator; |
| 24 | |
import java.util.List; |
| 25 | |
import java.util.Map; |
| 26 | |
import java.util.Set; |
| 27 | |
|
| 28 | |
import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer1; |
| 29 | |
import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer12; |
| 30 | |
import org.apache.commons.nabla.automatic.arithmetic.DAddTransformer2; |
| 31 | |
import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer1; |
| 32 | |
import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer12; |
| 33 | |
import org.apache.commons.nabla.automatic.arithmetic.DDivTransformer2; |
| 34 | |
import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer1; |
| 35 | |
import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer12; |
| 36 | |
import org.apache.commons.nabla.automatic.arithmetic.DMulTransformer2; |
| 37 | |
import org.apache.commons.nabla.automatic.arithmetic.DNegTransformer; |
| 38 | |
import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer1; |
| 39 | |
import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer12; |
| 40 | |
import org.apache.commons.nabla.automatic.arithmetic.DRemTransformer2; |
| 41 | |
import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer1; |
| 42 | |
import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer12; |
| 43 | |
import org.apache.commons.nabla.automatic.arithmetic.DSubTransformer2; |
| 44 | |
import org.apache.commons.nabla.automatic.functions.AcosTransformer; |
| 45 | |
import org.apache.commons.nabla.automatic.functions.AcoshTransformer; |
| 46 | |
import org.apache.commons.nabla.automatic.functions.AsinTransformer; |
| 47 | |
import org.apache.commons.nabla.automatic.functions.AsinhTransformer; |
| 48 | |
import org.apache.commons.nabla.automatic.functions.Atan2Transformer1; |
| 49 | |
import org.apache.commons.nabla.automatic.functions.Atan2Transformer12; |
| 50 | |
import org.apache.commons.nabla.automatic.functions.Atan2Transformer2; |
| 51 | |
import org.apache.commons.nabla.automatic.functions.AtanTransformer; |
| 52 | |
import org.apache.commons.nabla.automatic.functions.AtanhTransformer; |
| 53 | |
import org.apache.commons.nabla.automatic.functions.CbrtTransformer; |
| 54 | |
import org.apache.commons.nabla.automatic.functions.CosTransformer; |
| 55 | |
import org.apache.commons.nabla.automatic.functions.CoshTransformer; |
| 56 | |
import org.apache.commons.nabla.automatic.functions.ExpTransformer; |
| 57 | |
import org.apache.commons.nabla.automatic.functions.Expm1Transformer; |
| 58 | |
import org.apache.commons.nabla.automatic.functions.HypotTransformer1; |
| 59 | |
import org.apache.commons.nabla.automatic.functions.HypotTransformer12; |
| 60 | |
import org.apache.commons.nabla.automatic.functions.HypotTransformer2; |
| 61 | |
import org.apache.commons.nabla.automatic.functions.Log10Transformer; |
| 62 | |
import org.apache.commons.nabla.automatic.functions.Log1pTransformer; |
| 63 | |
import org.apache.commons.nabla.automatic.functions.LogTransformer; |
| 64 | |
import org.apache.commons.nabla.automatic.functions.MathInvocationTransformer; |
| 65 | |
import org.apache.commons.nabla.automatic.functions.PowTransformer1; |
| 66 | |
import org.apache.commons.nabla.automatic.functions.PowTransformer12; |
| 67 | |
import org.apache.commons.nabla.automatic.functions.PowTransformer2; |
| 68 | |
import org.apache.commons.nabla.automatic.functions.SinTransformer; |
| 69 | |
import org.apache.commons.nabla.automatic.functions.SinhTransformer; |
| 70 | |
import org.apache.commons.nabla.automatic.functions.SqrtTransformer; |
| 71 | |
import org.apache.commons.nabla.automatic.functions.TanTransformer; |
| 72 | |
import org.apache.commons.nabla.automatic.functions.TanhTransformer; |
| 73 | |
import org.apache.commons.nabla.automatic.instructions.DLoadTransformer; |
| 74 | |
import org.apache.commons.nabla.automatic.instructions.DReturnTransformer; |
| 75 | |
import org.apache.commons.nabla.automatic.instructions.DStoreTransformer; |
| 76 | |
import org.apache.commons.nabla.automatic.instructions.DcmpTransformer1; |
| 77 | |
import org.apache.commons.nabla.automatic.instructions.DcmpTransformer12; |
| 78 | |
import org.apache.commons.nabla.automatic.instructions.DcmpTransformer2; |
| 79 | |
import org.apache.commons.nabla.automatic.instructions.Dup2Transformer; |
| 80 | |
import org.apache.commons.nabla.automatic.instructions.Dup2X1Transformer; |
| 81 | |
import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer1; |
| 82 | |
import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer12; |
| 83 | |
import org.apache.commons.nabla.automatic.instructions.Dup2X2Transformer2; |
| 84 | |
import org.apache.commons.nabla.automatic.instructions.NarrowingTransformer; |
| 85 | |
import org.apache.commons.nabla.automatic.instructions.WideningTransformer; |
| 86 | |
import org.apache.commons.nabla.automatic.trimming.DLoadPop2Trimmer; |
| 87 | |
import org.apache.commons.nabla.automatic.trimming.SwappedDloadTrimmer; |
| 88 | |
import org.apache.commons.nabla.automatic.trimming.SwappedDstoreTrimmer; |
| 89 | |
import org.apache.commons.nabla.core.DifferentialPair; |
| 90 | |
import org.apache.commons.nabla.core.DifferentiationException; |
| 91 | |
import org.objectweb.asm.MethodVisitor; |
| 92 | |
import org.objectweb.asm.Opcodes; |
| 93 | |
import org.objectweb.asm.tree.AbstractInsnNode; |
| 94 | |
import org.objectweb.asm.tree.FieldInsnNode; |
| 95 | |
import org.objectweb.asm.tree.IincInsnNode; |
| 96 | |
import org.objectweb.asm.tree.InsnList; |
| 97 | |
import org.objectweb.asm.tree.InsnNode; |
| 98 | |
import org.objectweb.asm.tree.LabelNode; |
| 99 | |
import org.objectweb.asm.tree.MethodInsnNode; |
| 100 | |
import org.objectweb.asm.tree.MethodNode; |
| 101 | |
import org.objectweb.asm.tree.VarInsnNode; |
| 102 | |
import org.objectweb.asm.tree.analysis.Analyzer; |
| 103 | |
import org.objectweb.asm.tree.analysis.AnalyzerException; |
| 104 | |
import org.objectweb.asm.tree.analysis.BasicValue; |
| 105 | |
import org.objectweb.asm.tree.analysis.Frame; |
| 106 | |
import org.objectweb.asm.tree.analysis.Interpreter; |
| 107 | |
|
| 108 | |
|
| 109 | |
|
| 110 | |
|
| 111 | 432 | public class MethodDifferentiator extends MethodNode { |
| 112 | |
|
| 113 | |
|
| 114 | 1 | public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/'); |
| 115 | |
|
| 116 | |
|
| 117 | 1 | public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";"; |
| 118 | |
|
| 119 | |
|
| 120 | 1 | public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR; |
| 121 | |
|
| 122 | |
|
| 123 | |
private static final String VOID_RETURN_D_DESCRIPTOR = "()D"; |
| 124 | |
|
| 125 | |
|
| 126 | 1 | private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS = |
| 127 | |
new HashMap<String, MathInvocationTransformer>(); |
| 128 | |
|
| 129 | |
static { |
| 130 | 1 | MATH_TRANSFORMERS.put("acos", new AcosTransformer()); |
| 131 | 1 | MATH_TRANSFORMERS.put("acosh", new AcoshTransformer()); |
| 132 | 1 | MATH_TRANSFORMERS.put("asin", new AsinTransformer()); |
| 133 | 1 | MATH_TRANSFORMERS.put("asinh", new AsinhTransformer()); |
| 134 | 1 | MATH_TRANSFORMERS.put("atan2_12", new Atan2Transformer12()); |
| 135 | 1 | MATH_TRANSFORMERS.put("atan2_1", new Atan2Transformer1()); |
| 136 | 1 | MATH_TRANSFORMERS.put("atan2_2", new Atan2Transformer2()); |
| 137 | 1 | MATH_TRANSFORMERS.put("atan", new AtanTransformer()); |
| 138 | 1 | MATH_TRANSFORMERS.put("atanh", new AtanhTransformer()); |
| 139 | 1 | MATH_TRANSFORMERS.put("cbrt", new CbrtTransformer()); |
| 140 | 1 | MATH_TRANSFORMERS.put("cos", new CosTransformer()); |
| 141 | 1 | MATH_TRANSFORMERS.put("cosh", new CoshTransformer()); |
| 142 | 1 | MATH_TRANSFORMERS.put("exp", new ExpTransformer()); |
| 143 | 1 | MATH_TRANSFORMERS.put("expm1", new Expm1Transformer()); |
| 144 | 1 | MATH_TRANSFORMERS.put("hypot_12", new HypotTransformer12()); |
| 145 | 1 | MATH_TRANSFORMERS.put("hypot_1", new HypotTransformer1()); |
| 146 | 1 | MATH_TRANSFORMERS.put("hypot_2", new HypotTransformer2()); |
| 147 | 1 | MATH_TRANSFORMERS.put("log10", new Log10Transformer()); |
| 148 | 1 | MATH_TRANSFORMERS.put("log1p", new Log1pTransformer()); |
| 149 | 1 | MATH_TRANSFORMERS.put("log", new LogTransformer()); |
| 150 | 1 | MATH_TRANSFORMERS.put("pow_12", new PowTransformer12()); |
| 151 | 1 | MATH_TRANSFORMERS.put("pow_1", new PowTransformer1()); |
| 152 | 1 | MATH_TRANSFORMERS.put("pow_2", new PowTransformer2()); |
| 153 | 1 | MATH_TRANSFORMERS.put("sin", new SinTransformer()); |
| 154 | 1 | MATH_TRANSFORMERS.put("sinh", new SinhTransformer()); |
| 155 | 1 | MATH_TRANSFORMERS.put("sqrt", new SqrtTransformer()); |
| 156 | 1 | MATH_TRANSFORMERS.put("tan", new TanTransformer()); |
| 157 | 1 | MATH_TRANSFORMERS.put("tanh", new TanhTransformer()); |
| 158 | 1 | } |
| 159 | |
|
| 160 | |
|
| 161 | |
private static final String UNKNOWN_METHOD_FMT = "unknown method {0}.{1}"; |
| 162 | |
|
| 163 | |
|
| 164 | |
private static final int MAX_TEMP = 5; |
| 165 | |
|
| 166 | |
|
| 167 | |
private final Set<String> mathClasses; |
| 168 | |
|
| 169 | |
|
| 170 | |
private final MethodVisitor generator; |
| 171 | |
|
| 172 | |
|
| 173 | |
private boolean[] usedLocals; |
| 174 | |
|
| 175 | |
|
| 176 | |
private final String primitiveName; |
| 177 | |
|
| 178 | |
|
| 179 | |
private final ErrorReporter errorReporter; |
| 180 | |
|
| 181 | |
|
| 182 | |
private final Set<TrackingValue> converted; |
| 183 | |
|
| 184 | |
|
| 185 | |
private final Map<AbstractInsnNode, Frame> frames; |
| 186 | |
|
| 187 | |
|
| 188 | |
private final Map<AbstractInsnNode, Set<AbstractInsnNode>> successors; |
| 189 | |
|
| 190 | |
|
| 191 | |
private final Map<LabelNode, LabelNode> clonedLabels; |
| 192 | |
|
| 193 | |
|
| 194 | |
|
| 195 | |
|
| 196 | |
|
| 197 | |
|
| 198 | |
|
| 199 | |
|
| 200 | |
|
| 201 | |
|
| 202 | |
|
| 203 | |
|
| 204 | |
public MethodDifferentiator(final int access, final String name, final String desc, |
| 205 | |
final String signature, final String[] exceptions, |
| 206 | |
final MethodVisitor generator,final String primitiveName, |
| 207 | |
final Set<String> mathClasses, |
| 208 | |
final ErrorReporter errorReporter) { |
| 209 | |
|
| 210 | 66 | super(access, name, desc, signature, exceptions); |
| 211 | 66 | this.generator = generator; |
| 212 | 66 | this.usedLocals = null; |
| 213 | 66 | this.primitiveName = primitiveName; |
| 214 | 66 | this.mathClasses = mathClasses; |
| 215 | 66 | this.errorReporter = errorReporter; |
| 216 | 66 | this.converted = new HashSet<TrackingValue>(); |
| 217 | 66 | this.frames = new IdentityHashMap<AbstractInsnNode, Frame>(); |
| 218 | 66 | this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>(); |
| 219 | 66 | this.clonedLabels = new HashMap<LabelNode, LabelNode>(); |
| 220 | |
|
| 221 | 66 | } |
| 222 | |
|
| 223 | |
|
| 224 | |
@Override |
| 225 | |
public void visitEnd() { |
| 226 | |
try { |
| 227 | |
|
| 228 | |
|
| 229 | 66 | maxLocals = 2 * (maxLocals + MAX_TEMP) - 1; |
| 230 | 66 | usedLocals = new boolean[maxLocals]; |
| 231 | 66 | useLocal(0, 1); |
| 232 | 66 | useLocal(1, 4); |
| 233 | |
|
| 234 | |
|
| 235 | 66 | addSpareLocalVariables(); |
| 236 | |
|
| 237 | |
|
| 238 | 66 | final Frame[] array = |
| 239 | |
new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this); |
| 240 | |
|
| 241 | |
|
| 242 | 347 | for (int i = 0; i < array.length; ++i) { |
| 243 | 281 | frames.put(instructions.get(i), array[i]); |
| 244 | |
} |
| 245 | |
|
| 246 | |
|
| 247 | 66 | final Set<AbstractInsnNode> changes = identifyChanges(); |
| 248 | |
|
| 249 | 66 | if (changes.isEmpty()) { |
| 250 | |
|
| 251 | |
|
| 252 | |
|
| 253 | 1 | instructions.clear(); |
| 254 | 1 | instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR)); |
| 255 | 1 | instructions.add(new InsnNode(Opcodes.ARETURN)); |
| 256 | |
|
| 257 | |
} else { |
| 258 | |
|
| 259 | |
|
| 260 | 65 | changeCode(changes); |
| 261 | |
|
| 262 | |
|
| 263 | 65 | removeUnusedSpareLocalVariables(); |
| 264 | |
|
| 265 | |
|
| 266 | 65 | SwappedDloadTrimmer.getInstance().trim(instructions); |
| 267 | 65 | SwappedDstoreTrimmer.getInstance().trim(instructions); |
| 268 | 65 | DLoadPop2Trimmer.getInstance().trim(instructions); |
| 269 | |
|
| 270 | |
} |
| 271 | |
|
| 272 | |
|
| 273 | 66 | desc = DP_RETURN_DP_DESCRIPTOR; |
| 274 | |
|
| 275 | |
|
| 276 | 66 | accept(generator); |
| 277 | |
|
| 278 | 0 | } catch (AnalyzerException ae) { |
| 279 | 0 | if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) { |
| 280 | 0 | errorReporter.register((DifferentiationException) ae.getCause()); |
| 281 | |
} else { |
| 282 | 0 | final DifferentiationException de = |
| 283 | |
new DifferentiationException("unable to analyze the {0}.{1} method ({2})", |
| 284 | |
new Object[] { |
| 285 | |
primitiveName, name, ae.getMessage() |
| 286 | |
}); |
| 287 | 0 | errorReporter.register(de); |
| 288 | |
} |
| 289 | 0 | } catch (DifferentiationException de) { |
| 290 | 0 | errorReporter.register(de); |
| 291 | 66 | } |
| 292 | 66 | } |
| 293 | |
|
| 294 | |
|
| 295 | |
|
| 296 | |
|
| 297 | |
|
| 298 | |
|
| 299 | |
|
| 300 | |
|
| 301 | |
|
| 302 | |
|
| 303 | |
|
| 304 | |
|
| 305 | |
|
| 306 | |
|
| 307 | |
private void addSpareLocalVariables() throws DifferentiationException { |
| 308 | 66 | for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { |
| 309 | 281 | final AbstractInsnNode insn = (AbstractInsnNode) i.next(); |
| 310 | 281 | if (insn.getType() == AbstractInsnNode.VAR_INSN) { |
| 311 | 95 | final VarInsnNode varInsn = (VarInsnNode) insn; |
| 312 | 95 | if (varInsn.var > 2) { |
| 313 | 15 | varInsn.var = 2 * varInsn.var - 1; |
| 314 | 15 | final int opcode = varInsn.getOpcode(); |
| 315 | 15 | if ((opcode == Opcodes.ILOAD) || (opcode == Opcodes.FLOAD) || |
| 316 | |
(opcode == Opcodes.ALOAD) || (opcode == Opcodes.ISTORE) || |
| 317 | |
(opcode == Opcodes.FSTORE) || (opcode == Opcodes.ASTORE)) { |
| 318 | 4 | useLocal(varInsn.var, 1); |
| 319 | |
} else { |
| 320 | 11 | useLocal(varInsn.var, 2); |
| 321 | |
} |
| 322 | |
} |
| 323 | 95 | } else if (insn.getOpcode() == Opcodes.IINC) { |
| 324 | 2 | final IincInsnNode iincInsn = (IincInsnNode) insn; |
| 325 | 2 | if (iincInsn.var > 2) { |
| 326 | 2 | iincInsn.var = 2 * iincInsn.var - 1; |
| 327 | 2 | useLocal(iincInsn.var, 1); |
| 328 | |
} |
| 329 | |
} |
| 330 | 281 | } |
| 331 | 66 | } |
| 332 | |
|
| 333 | |
|
| 334 | |
|
| 335 | |
|
| 336 | |
private void removeUnusedSpareLocalVariables() { |
| 337 | 65 | for (final Iterator<?> i = instructions.iterator(); i.hasNext();) { |
| 338 | 1973 | final AbstractInsnNode insn = (AbstractInsnNode) i.next(); |
| 339 | 1973 | if (insn.getType() == AbstractInsnNode.VAR_INSN) { |
| 340 | 873 | shiftVariable((VarInsnNode) insn); |
| 341 | |
} |
| 342 | 1973 | } |
| 343 | 65 | } |
| 344 | |
|
| 345 | |
|
| 346 | |
|
| 347 | |
|
| 348 | |
|
| 349 | |
|
| 350 | |
|
| 351 | |
|
| 352 | |
|
| 353 | |
|
| 354 | |
private Set<AbstractInsnNode> identifyChanges() { |
| 355 | |
|
| 356 | |
|
| 357 | |
|
| 358 | |
|
| 359 | 66 | final Set<TrackingValue> pending = new HashSet<TrackingValue>(); |
| 360 | |
|
| 361 | |
|
| 362 | 66 | final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>(); |
| 363 | |
|
| 364 | |
|
| 365 | |
|
| 366 | 66 | final TrackingValue dpParameter = (TrackingValue) frames.get(instructions.get(0)).getLocal(1); |
| 367 | 66 | pending.add(dpParameter); |
| 368 | |
|
| 369 | |
|
| 370 | 210 | while (!pending.isEmpty()) { |
| 371 | |
|
| 372 | |
|
| 373 | 144 | final Iterator<TrackingValue> iterator = pending.iterator(); |
| 374 | 144 | final TrackingValue value = iterator.next(); |
| 375 | 144 | iterator.remove(); |
| 376 | |
|
| 377 | |
|
| 378 | 144 | converted.add(value); |
| 379 | |
|
| 380 | |
|
| 381 | 144 | for (final AbstractInsnNode consumer : value.getConsumers()) { |
| 382 | |
|
| 383 | |
|
| 384 | |
|
| 385 | |
|
| 386 | 252 | for (TrackingValue produced : getProducedDoubleValues(consumer)) { |
| 387 | |
|
| 388 | |
|
| 389 | 184 | if (!converted.contains(produced)) { |
| 390 | 83 | pending.add(produced); |
| 391 | |
} |
| 392 | |
|
| 393 | |
} |
| 394 | |
|
| 395 | |
|
| 396 | 252 | changes.add(consumer); |
| 397 | |
|
| 398 | |
} |
| 399 | |
|
| 400 | |
|
| 401 | 144 | for (final AbstractInsnNode producer : value.getProducers()) { |
| 402 | |
|
| 403 | |
|
| 404 | 181 | changes.add(producer); |
| 405 | |
|
| 406 | |
} |
| 407 | 144 | } |
| 408 | |
|
| 409 | 66 | return changes; |
| 410 | |
|
| 411 | |
} |
| 412 | |
|
| 413 | |
|
| 414 | |
|
| 415 | |
|
| 416 | |
|
| 417 | |
private List<TrackingValue> getProducedDoubleValues(final AbstractInsnNode instruction) { |
| 418 | |
|
| 419 | 252 | final List<TrackingValue> values = new ArrayList<TrackingValue>(); |
| 420 | |
|
| 421 | |
|
| 422 | 252 | final Frame before = frames.get(instruction); |
| 423 | 252 | final int beforeStackSize = before.getStackSize(); |
| 424 | 252 | final int locals = before.getLocals(); |
| 425 | |
|
| 426 | |
|
| 427 | |
|
| 428 | 252 | final Set<AbstractInsnNode> set = successors.get(instruction); |
| 429 | 252 | if (set != null) { |
| 430 | |
|
| 431 | |
|
| 432 | 185 | for (final AbstractInsnNode successor : set) { |
| 433 | 185 | final Frame produced = frames.get(successor); |
| 434 | |
|
| 435 | |
|
| 436 | 393 | for (int i = 0; i < produced.getStackSize(); ++i) { |
| 437 | 208 | final TrackingValue value = (TrackingValue) produced.getStack(i); |
| 438 | 208 | if (((i >= beforeStackSize) || (value != before.getStack(i))) && |
| 439 | |
value.getValue().equals(BasicValue.DOUBLE_VALUE)) { |
| 440 | 175 | values.add(value); |
| 441 | |
} |
| 442 | |
} |
| 443 | |
|
| 444 | |
|
| 445 | 3128 | for (int i = 0; i < locals; ++i) { |
| 446 | 2943 | final TrackingValue value = (TrackingValue) produced.getLocal(i); |
| 447 | 2943 | if ((value != before.getLocal(i)) && |
| 448 | |
value.getValue().equals(BasicValue.DOUBLE_VALUE)) { |
| 449 | 9 | values.add(value); |
| 450 | |
} |
| 451 | |
} |
| 452 | 185 | } |
| 453 | |
} |
| 454 | |
|
| 455 | 252 | return values; |
| 456 | |
|
| 457 | |
} |
| 458 | |
|
| 459 | |
|
| 460 | |
|
| 461 | |
|
| 462 | |
|
| 463 | |
private void changeCode(final Set<AbstractInsnNode> changes) |
| 464 | |
throws DifferentiationException { |
| 465 | |
|
| 466 | |
|
| 467 | 65 | final InsnList list = new InsnList(); |
| 468 | 65 | list.add(new VarInsnNode(Opcodes.ALOAD, 1)); |
| 469 | 65 | list.add(new InsnNode(Opcodes.DUP)); |
| 470 | 65 | list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME, |
| 471 | |
"getValue", VOID_RETURN_D_DESCRIPTOR)); |
| 472 | 65 | list.add(new VarInsnNode(Opcodes.DSTORE, 1)); |
| 473 | 65 | list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME, |
| 474 | |
"getFirstDerivative", VOID_RETURN_D_DESCRIPTOR)); |
| 475 | 65 | list.add(new VarInsnNode(Opcodes.DSTORE, 3)); |
| 476 | |
|
| 477 | 65 | instructions.insertBefore(instructions.get(0), list); |
| 478 | |
|
| 479 | |
|
| 480 | 65 | for (final AbstractInsnNode insn : changes) { |
| 481 | 235 | instructions.insert(insn, getReplacement(insn)); |
| 482 | 235 | instructions.remove(insn); |
| 483 | |
} |
| 484 | |
|
| 485 | 65 | } |
| 486 | |
|
| 487 | |
|
| 488 | |
|
| 489 | |
|
| 490 | |
|
| 491 | |
|
| 492 | |
|
| 493 | |
private InsnList getReplacement(final AbstractInsnNode insn) |
| 494 | |
throws DifferentiationException { |
| 495 | |
|
| 496 | |
|
| 497 | 235 | final Frame frame = frames.get(insn); |
| 498 | 235 | final int size = frame.getStackSize(); |
| 499 | 235 | final boolean stack1Converted = (size > 0) && converted.contains(frame.getStack(size - 2)); |
| 500 | 235 | final boolean stack0Converted = (size > 1) && converted.contains(frame.getStack(size - 1)); |
| 501 | |
|
| 502 | 235 | switch(insn.getOpcode()) { |
| 503 | |
case Opcodes.DLOAD : |
| 504 | 86 | useLocal(((VarInsnNode) insn).var, 4); |
| 505 | 86 | return DLoadTransformer.getInstance().getReplacement(insn, this); |
| 506 | |
case Opcodes.DALOAD : |
| 507 | |
|
| 508 | 0 | throw new RuntimeException("DALOAD not handled yet"); |
| 509 | |
case Opcodes.DSTORE : |
| 510 | 5 | useLocal(((VarInsnNode) insn).var, 4); |
| 511 | 5 | return DStoreTransformer.getInstance().getReplacement(insn, this); |
| 512 | |
case Opcodes.DASTORE : |
| 513 | |
|
| 514 | 0 | throw new RuntimeException("DASTORE not handled yet"); |
| 515 | |
case Opcodes.DUP2 : |
| 516 | 0 | return Dup2Transformer.getInstance().getReplacement(insn, this); |
| 517 | |
case Opcodes.DUP2_X1 : |
| 518 | 0 | return Dup2X1Transformer.getInstance().getReplacement(insn, this); |
| 519 | |
case Opcodes.DUP2_X2 : |
| 520 | 0 | if (stack1Converted) { |
| 521 | 0 | if (stack0Converted) { |
| 522 | 0 | return Dup2X2Transformer12.getInstance().getReplacement(insn, this); |
| 523 | |
} |
| 524 | 0 | return Dup2X2Transformer1.getInstance().getReplacement(insn, this); |
| 525 | |
} |
| 526 | 0 | return Dup2X2Transformer2.getInstance().getReplacement(insn, this); |
| 527 | |
case Opcodes.DADD : |
| 528 | 8 | if (stack1Converted) { |
| 529 | 7 | if (stack0Converted) { |
| 530 | 1 | return DAddTransformer12.getInstance().getReplacement(insn, this); |
| 531 | |
} |
| 532 | 6 | return DAddTransformer1.getInstance().getReplacement(insn, this); |
| 533 | |
} |
| 534 | 1 | return DAddTransformer2.getInstance().getReplacement(insn, this); |
| 535 | |
case Opcodes.DSUB : |
| 536 | 5 | if (stack1Converted) { |
| 537 | 4 | if (stack0Converted) { |
| 538 | 1 | return DSubTransformer12.getInstance().getReplacement(insn, this); |
| 539 | |
} |
| 540 | 3 | return DSubTransformer1.getInstance().getReplacement(insn, this); |
| 541 | |
} |
| 542 | 1 | return DSubTransformer2.getInstance().getReplacement(insn, this); |
| 543 | |
case Opcodes.DMUL : |
| 544 | 12 | if (stack1Converted) { |
| 545 | 9 | if (stack0Converted) { |
| 546 | 8 | return DMulTransformer12.getInstance().getReplacement(insn, this); |
| 547 | |
} |
| 548 | 1 | return DMulTransformer1.getInstance().getReplacement(insn, this); |
| 549 | |
} |
| 550 | 3 | return DMulTransformer2.getInstance().getReplacement(insn, this); |
| 551 | |
case Opcodes.DDIV : |
| 552 | 4 | if (stack1Converted) { |
| 553 | 2 | if (stack0Converted) { |
| 554 | 1 | return DDivTransformer12.getInstance().getReplacement(insn, this); |
| 555 | |
} |
| 556 | 1 | return DDivTransformer1.getInstance().getReplacement(insn, this); |
| 557 | |
} |
| 558 | 2 | return DDivTransformer2.getInstance().getReplacement(insn, this); |
| 559 | |
case Opcodes.DREM : |
| 560 | 3 | if (stack1Converted) { |
| 561 | 2 | if (stack0Converted) { |
| 562 | 1 | return DRemTransformer12.getInstance().getReplacement(insn, this); |
| 563 | |
} |
| 564 | 1 | return DRemTransformer1.getInstance().getReplacement(insn, this); |
| 565 | |
} |
| 566 | 1 | return DRemTransformer2.getInstance().getReplacement(insn, this); |
| 567 | |
case Opcodes.DNEG : |
| 568 | 1 | return DNegTransformer.getInstance().getReplacement(insn, this); |
| 569 | |
case Opcodes.DCONST_0 : |
| 570 | |
case Opcodes.DCONST_1 : |
| 571 | |
case Opcodes.LDC : |
| 572 | |
case Opcodes.I2D : |
| 573 | |
case Opcodes.L2D : |
| 574 | |
case Opcodes.F2D : |
| 575 | 2 | return WideningTransformer.getInstance().getReplacement(insn, this); |
| 576 | |
case Opcodes.POP2 : |
| 577 | |
case Opcodes.D2I : |
| 578 | |
case Opcodes.D2L : |
| 579 | |
case Opcodes.D2F : |
| 580 | 1 | return NarrowingTransformer.getInstance().getReplacement(insn, this); |
| 581 | |
case Opcodes.DCMPL : |
| 582 | |
case Opcodes.DCMPG : |
| 583 | 0 | if (stack1Converted) { |
| 584 | 0 | if (stack0Converted) { |
| 585 | 0 | return DcmpTransformer12.getInstance().getReplacement(insn, this); |
| 586 | |
} |
| 587 | 0 | return DcmpTransformer1.getInstance().getReplacement(insn, this); |
| 588 | |
} |
| 589 | 0 | return DcmpTransformer2.getInstance().getReplacement(insn, this); |
| 590 | |
case Opcodes.DRETURN : |
| 591 | 65 | return DReturnTransformer.getInstance().getReplacement(insn, this); |
| 592 | |
case Opcodes.GETSTATIC : |
| 593 | |
|
| 594 | 0 | throw new RuntimeException("GETSTATIC not handled yet"); |
| 595 | |
case Opcodes.PUTSTATIC : |
| 596 | |
|
| 597 | 0 | throw new RuntimeException("PUTSTATIC not handled yet"); |
| 598 | |
case Opcodes.GETFIELD : |
| 599 | |
|
| 600 | 0 | throw new RuntimeException("GETFIELD not handled yet"); |
| 601 | |
case Opcodes.PUTFIELD : |
| 602 | |
|
| 603 | 0 | throw new RuntimeException("PUTFIELD not handled yet"); |
| 604 | |
case Opcodes.INVOKEVIRTUAL : |
| 605 | |
|
| 606 | 0 | throw new RuntimeException("INVOKEVIRTUAL not handled yet"); |
| 607 | |
case Opcodes.INVOKESPECIAL : |
| 608 | |
|
| 609 | 0 | throw new RuntimeException("INVOKESPECIAL not handled yet"); |
| 610 | |
case Opcodes.INVOKESTATIC : |
| 611 | 43 | return replaceInvocation((MethodInsnNode) insn, |
| 612 | |
stack1Converted, stack0Converted); |
| 613 | |
case Opcodes.INVOKEINTERFACE : |
| 614 | |
|
| 615 | 0 | throw new RuntimeException("INVOKEINTERFACE not handled yet"); |
| 616 | |
case Opcodes.NEWARRAY : |
| 617 | |
|
| 618 | 0 | throw new RuntimeException("NEWARRAY not handled yet"); |
| 619 | |
case Opcodes.ANEWARRAY : |
| 620 | |
|
| 621 | 0 | throw new RuntimeException("ANEWARRAY not handled yet"); |
| 622 | |
case Opcodes.MULTIANEWARRAY : |
| 623 | |
|
| 624 | 0 | throw new RuntimeException("MULTIANEWARRAY not handled yet"); |
| 625 | |
default: |
| 626 | 0 | throw new DifferentiationException("unable to handle instruction with opcode {0}", |
| 627 | |
new Object[] { |
| 628 | |
Integer.valueOf(insn.getOpcode()) |
| 629 | |
}); |
| 630 | |
} |
| 631 | |
|
| 632 | |
} |
| 633 | |
|
| 634 | |
|
| 635 | |
|
| 636 | |
|
| 637 | |
|
| 638 | |
|
| 639 | |
|
| 640 | |
|
| 641 | |
|
| 642 | |
|
| 643 | |
private InsnList replaceInvocation(final MethodInsnNode methodInsn, |
| 644 | |
final boolean stack1Converted, |
| 645 | |
final boolean stack0Converted) |
| 646 | |
throws DifferentiationException { |
| 647 | 43 | if (isMathImplementationClass(methodInsn.owner)) { |
| 648 | 43 | if ("(D)D".equals(methodInsn.desc)) { |
| 649 | |
|
| 650 | 32 | final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name); |
| 651 | 32 | if (transformer == null) { |
| 652 | 0 | throw new DifferentiationException(UNKNOWN_METHOD_FMT, |
| 653 | |
methodInsn.owner, methodInsn.name); |
| 654 | |
} |
| 655 | 32 | return transformer.getReplacementList(methodInsn.owner, this); |
| 656 | 11 | } else if ("(DD)D".equals(methodInsn.desc)) { |
| 657 | |
|
| 658 | |
|
| 659 | |
|
| 660 | 11 | String name = null; |
| 661 | 11 | if (stack1Converted) { |
| 662 | 8 | if (stack0Converted) { |
| 663 | 5 | name = methodInsn.name + "_12"; |
| 664 | |
} else { |
| 665 | 3 | name = methodInsn.name + "_1"; |
| 666 | |
} |
| 667 | 3 | } else if (stack0Converted) { |
| 668 | 3 | name = methodInsn.name + "_2"; |
| 669 | |
} |
| 670 | |
|
| 671 | 11 | if (name != null) { |
| 672 | 11 | final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name); |
| 673 | 11 | if (transformer == null) { |
| 674 | 0 | throw new DifferentiationException(UNKNOWN_METHOD_FMT, |
| 675 | |
methodInsn.owner, methodInsn.name); |
| 676 | |
} |
| 677 | 11 | return transformer.getReplacementList(methodInsn.owner, this); |
| 678 | |
} |
| 679 | |
} |
| 680 | |
} |
| 681 | 0 | throw new DifferentiationException("unexpected instruction {0}", |
| 682 | |
Integer.valueOf(methodInsn.getOpcode())); |
| 683 | |
} |
| 684 | |
|
| 685 | |
|
| 686 | |
|
| 687 | |
|
| 688 | |
|
| 689 | |
public boolean isMathImplementationClass(final String name) { |
| 690 | 43 | return mathClasses.contains(name); |
| 691 | |
} |
| 692 | |
|
| 693 | |
|
| 694 | |
|
| 695 | |
|
| 696 | |
|
| 697 | |
|
| 698 | |
|
| 699 | |
|
| 700 | |
public void useLocal(final int index, final int size) |
| 701 | |
throws DifferentiationException { |
| 702 | 394 | if ((index < 0) || ((index + size - 1) >= usedLocals.length)) { |
| 703 | 0 | throw new DifferentiationException("index of size {0} local variable ({1}) " + |
| 704 | |
"outside of [{2}, {3}] range", |
| 705 | |
Integer.valueOf(size), Integer.valueOf(index), |
| 706 | |
Integer.valueOf(1), Integer.valueOf(MAX_TEMP)); |
| 707 | |
} |
| 708 | 1554 | for (int i = index; i < index + size; ++i) { |
| 709 | 1160 | usedLocals[i] = true; |
| 710 | |
} |
| 711 | 394 | } |
| 712 | |
|
| 713 | |
|
| 714 | |
|
| 715 | |
|
| 716 | |
|
| 717 | |
|
| 718 | |
|
| 719 | |
|
| 720 | |
|
| 721 | |
|
| 722 | |
|
| 723 | |
|
| 724 | |
|
| 725 | |
|
| 726 | |
|
| 727 | |
|
| 728 | |
|
| 729 | |
|
| 730 | |
|
| 731 | |
|
| 732 | |
|
| 733 | |
public int getTmp(final int number) throws DifferentiationException { |
| 734 | 89 | if ((number < 0) || (number > MAX_TEMP)) { |
| 735 | 0 | throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range", |
| 736 | |
Integer.valueOf(number), |
| 737 | |
Integer.valueOf(1), |
| 738 | |
Integer.valueOf(MAX_TEMP)); |
| 739 | |
} |
| 740 | 89 | final int index = usedLocals.length - 2 * number; |
| 741 | 89 | useLocal(index, 2); |
| 742 | 89 | return index; |
| 743 | |
} |
| 744 | |
|
| 745 | |
|
| 746 | |
|
| 747 | |
|
| 748 | |
public void shiftVariable(final VarInsnNode insn) { |
| 749 | 873 | int shifted = 0; |
| 750 | 5026 | for (int i = 0; i < insn.var; ++i) { |
| 751 | 4153 | if (usedLocals[i]) { |
| 752 | 3141 | ++shifted; |
| 753 | |
} |
| 754 | |
} |
| 755 | 873 | insn.var = shifted; |
| 756 | 873 | } |
| 757 | |
|
| 758 | |
|
| 759 | |
|
| 760 | |
|
| 761 | |
|
| 762 | |
public AbstractInsnNode clone(final AbstractInsnNode insn) { |
| 763 | 3 | return insn.clone(clonedLabels); |
| 764 | |
} |
| 765 | |
|
| 766 | |
|
| 767 | |
private class FlowAnalyzer extends Analyzer { |
| 768 | |
|
| 769 | |
|
| 770 | |
|
| 771 | |
|
| 772 | 66 | public FlowAnalyzer(final Interpreter interpreter) { |
| 773 | 66 | super(interpreter); |
| 774 | 66 | } |
| 775 | |
|
| 776 | |
|
| 777 | |
|
| 778 | |
|
| 779 | |
|
| 780 | |
protected void newControlFlowEdge(final int insn, final int successor) { |
| 781 | |
|
| 782 | 217 | final AbstractInsnNode node = instructions.get(insn); |
| 783 | 217 | Set<AbstractInsnNode> set = successors.get(node); |
| 784 | 217 | if (set == null) { |
| 785 | 215 | set = new HashSet<AbstractInsnNode>(); |
| 786 | 215 | successors.put(node, set); |
| 787 | |
} |
| 788 | 217 | set.add(instructions.get(successor)); |
| 789 | 217 | } |
| 790 | |
|
| 791 | |
} |
| 792 | |
|
| 793 | |
} |