/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017package org.apache.commons.proxy2.asm;
018
019import java.io.Serializable;
020import java.lang.reflect.InvocationTargetException;
021import java.lang.reflect.Method;
022import java.lang.reflect.UndeclaredThrowableException;
023import java.util.concurrent.atomic.AtomicInteger;
024
025import org.apache.commons.lang3.ArrayUtils;
026import org.apache.commons.lang3.ObjectUtils;
027import org.apache.commons.proxy2.Interceptor;
028import org.apache.commons.proxy2.Invocation;
029import org.apache.commons.proxy2.Invoker;
030import org.apache.commons.proxy2.ObjectProvider;
031import org.apache.commons.proxy2.ProxyUtils;
032import org.apache.commons.proxy2.exception.ProxyFactoryException;
033import org.apache.commons.proxy2.impl.AbstractProxyClassGenerator;
034import org.apache.commons.proxy2.impl.AbstractSubclassingProxyFactory;
035import org.apache.commons.proxy2.impl.ProxyClassCache;
036import org.objectweb.asm.ClassWriter;
037import org.objectweb.asm.Label;
038import org.objectweb.asm.Opcodes;
039import org.objectweb.asm.Type;
040import org.objectweb.asm.commons.GeneratorAdapter;
041
042public class ASMProxyFactory extends AbstractSubclassingProxyFactory
043{
044    private static final ProxyClassCache PROXY_CLASS_CACHE = new ProxyClassCache(new ProxyGenerator());
045
046    @Override
047    public <T> T createDelegatorProxy(final ClassLoader classLoader, final ObjectProvider<?> delegateProvider,
048            final Class<?>... proxyClasses)
049    {
050        return createProxy(classLoader, new DelegatorInvoker(delegateProvider), proxyClasses);
051    }
052
053    @Override
054    public <T> T createInterceptorProxy(final ClassLoader classLoader, final Object target,
055            final Interceptor interceptor, final Class<?>... proxyClasses)
056    {
057        return createProxy(classLoader, new InterceptorInvoker(target, interceptor), proxyClasses);
058    }
059
060    @Override
061    public <T> T createInvokerProxy(final ClassLoader classLoader, final Invoker invoker,
062            final Class<?>... proxyClasses)
063    {
064        return createProxy(classLoader, new InvokerInvoker(invoker), proxyClasses);
065    }
066
067    private <T> T createProxy(final ClassLoader classLoader, final AbstractInvoker invoker,
068            final Class<?>... proxyClasses)
069    {
070        final Class<?> proxyClass = PROXY_CLASS_CACHE.getProxyClass(classLoader, proxyClasses);
071        try
072        {
073            @SuppressWarnings("unchecked") // type inference
074            final T result = (T) proxyClass.getConstructor(Invoker.class).newInstance(invoker);
075            return result;
076        }
077        catch (Exception e)
078        {
079            throw e instanceof RuntimeException ? ((RuntimeException) e) : new RuntimeException(e);
080        }
081    }
082
083    private static class ProxyGenerator extends AbstractProxyClassGenerator implements Opcodes
084    {
085        private static final AtomicInteger CLASS_NUMBER = new AtomicInteger(0);
086        private static final String CLASSNAME_PREFIX = "CommonsProxyASM_";
087        private static final String HANDLER_NAME = "__handler";
088        private static final Type INVOKER_TYPE = Type.getType(Invoker.class);
089
090        @Override
091        public Class<?> generateProxyClass(final ClassLoader classLoader, final Class<?>... proxyClasses)
092        {
093            final Class<?> superclass = getSuperclass(proxyClasses);
094            final String proxyName = CLASSNAME_PREFIX + CLASS_NUMBER.incrementAndGet();
095            final Method[] implementationMethods = getImplementationMethods(proxyClasses);
096            final Class<?>[] interfaces = toInterfaces(proxyClasses);
097            final String classFileName = proxyName.replace('.', '/');
098
099            try
100            {
101                final byte[] proxyBytes = generateProxy(superclass, classFileName, implementationMethods, interfaces);
102                return loadClass(classLoader, proxyName, proxyBytes);
103            }
104            catch (final Exception e)
105            {
106                throw new ProxyFactoryException(e);
107            }
108        }
109
110        private static byte[] generateProxy(final Class<?> classToProxy, final String proxyName,
111                final Method[] methods, final Class<?>... interfaces) throws ProxyFactoryException
112        {
113            final ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
114
115            final Type proxyType = Type.getObjectType(proxyName);
116
117            // push class signature
118            final String[] interfaceNames = new String[interfaces.length];
119            for (int i = 0; i < interfaces.length; i++)
120            {
121                interfaceNames[i] = Type.getType(interfaces[i]).getInternalName();
122            }
123
124            final Type superType = Type.getType(classToProxy);
125            cw.visit(V1_6, ACC_PUBLIC + ACC_SUPER, proxyType.getInternalName(), null, superType.getInternalName(),
126                    interfaceNames);
127
128            // create Invoker field
129            cw.visitField(ACC_FINAL + ACC_PRIVATE, HANDLER_NAME, INVOKER_TYPE.getDescriptor(), null, null).visitEnd();
130
131            init(cw, proxyType, superType);
132
133            for (final Method method : methods)
134            {
135                processMethod(cw, method, proxyType, HANDLER_NAME);
136            }
137
138            return cw.toByteArray();
139        }
140
141        private static void init(final ClassWriter cw, final Type proxyType, Type superType)
142        {
143            final GeneratorAdapter mg = new GeneratorAdapter(ACC_PUBLIC, new org.objectweb.asm.commons.Method("<init>",
144                    Type.VOID_TYPE, new Type[] { INVOKER_TYPE }), null, null, cw);
145            // invoke super constructor:
146            mg.loadThis();
147            mg.invokeConstructor(superType, org.objectweb.asm.commons.Method.getMethod("void <init> ()"));
148
149            // assign handler:
150            mg.loadThis();
151            mg.loadArg(0);
152            mg.putField(proxyType, HANDLER_NAME, INVOKER_TYPE);
153            mg.returnValue();
154            mg.endMethod();
155        }
156
157        private static void processMethod(final ClassWriter cw, final Method method, final Type proxyType,
158                final String handlerName) throws ProxyFactoryException
159        {
160            final Type sig = Type.getType(method);
161            final Type[] exceptionTypes = getTypes(method.getExceptionTypes());
162
163            // push the method definition
164            final int access = (ACC_PUBLIC | ACC_PROTECTED) & method.getModifiers();
165            final org.objectweb.asm.commons.Method m = org.objectweb.asm.commons.Method.getMethod(method);
166            final GeneratorAdapter mg = new GeneratorAdapter(access, m, null, getTypes(method.getExceptionTypes()), cw);
167
168            final Label tryBlock = exceptionTypes.length > 0 ? mg.mark() : null;
169
170            mg.push(Type.getType(method.getDeclaringClass()));
171
172            // the following code generates the bytecode for this line of Java:
173            // Method method = <proxy>.class.getMethod("add", new Class[] {
174            // <array of function argument classes> });
175
176            // get the method name to invoke, and push to stack
177
178            mg.push(method.getName());
179
180            // create the Class[]
181            mg.push(sig.getArgumentTypes().length);
182            final Type classType = Type.getType(Class.class);
183            mg.newArray(classType);
184
185            // push parameters into array
186            for (int i = 0; i < sig.getArgumentTypes().length; i++)
187            {
188                // keep copy of array on stack
189                mg.dup();
190
191                // push index onto stack
192                mg.push(i);
193                mg.push(sig.getArgumentTypes()[i]);
194                mg.arrayStore(classType);
195            }
196
197            // invoke getMethod() with the method name and the array of types
198            mg.invokeVirtual(classType, org.objectweb.asm.commons.Method
199                    .getMethod("java.lang.reflect.Method getDeclaredMethod(String, Class[])"));
200            // store the returned method for later
201
202            // the following code generates bytecode equivalent to:
203            // return ((<returntype>) invoker.invoke(this, method, new Object[]
204            // { <function arguments }))[.<primitive>Value()];
205
206            mg.loadThis();
207
208            mg.getField(proxyType, handlerName, INVOKER_TYPE);
209            // put below method:
210            mg.swap();
211
212            // we want to pass "this" in as the first parameter
213            mg.loadThis();
214            // put below method:
215            mg.swap();
216
217            // need to construct the array of objects passed in
218
219            // create the Object[]
220            mg.push(sig.getArgumentTypes().length);
221            final Type objectType = Type.getType(Object.class);
222            mg.newArray(objectType);
223
224            // push parameters into array
225            for (int i = 0; i < sig.getArgumentTypes().length; i++)
226            {
227                // keep copy of array on stack
228                mg.dup();
229
230                // push index onto stack
231                mg.push(i);
232
233                mg.loadArg(i);
234                mg.valueOf(sig.getArgumentTypes()[i]);
235                mg.arrayStore(objectType);
236            }
237
238            // invoke the invoker
239            mg.invokeInterface(INVOKER_TYPE, org.objectweb.asm.commons.Method
240                    .getMethod("Object invoke(Object, java.lang.reflect.Method, Object[])"));
241
242            // cast the result
243            mg.unbox(sig.getReturnType());
244
245            // push return
246            mg.returnValue();
247
248            // catch InvocationTargetException
249            if (exceptionTypes.length > 0)
250            {
251                final Type caughtExceptionType = Type.getType(InvocationTargetException.class);
252                mg.catchException(tryBlock, mg.mark(), caughtExceptionType);
253
254                final Label throwCause = new Label();
255
256                mg.invokeVirtual(caughtExceptionType,
257                        org.objectweb.asm.commons.Method.getMethod("Throwable getCause()"));
258
259                for (int i = 0; i < exceptionTypes.length; i++)
260                {
261                    mg.dup();
262                    mg.push(exceptionTypes[i]);
263                    mg.swap();
264                    mg.invokeVirtual(classType,
265                            org.objectweb.asm.commons.Method.getMethod("boolean isInstance(Object)"));
266                    // if true, throw cause:
267                    mg.ifZCmp(GeneratorAdapter.NE, throwCause);
268                }
269                // no exception types matched; throw
270                // UndeclaredThrowableException:
271                final int cause = mg.newLocal(Type.getType(Exception.class));
272                mg.storeLocal(cause);
273                final Type undeclaredType = Type.getType(UndeclaredThrowableException.class);
274                mg.newInstance(undeclaredType);
275                mg.dup();
276                mg.loadLocal(cause);
277                mg.invokeConstructor(undeclaredType, new org.objectweb.asm.commons.Method("<init>", Type.VOID_TYPE,
278                        new Type[] { Type.getType(Throwable.class) }));
279                mg.throwException();
280
281                mg.mark(throwCause);
282                mg.throwException();
283            }
284
285            // finish this method
286            mg.endMethod();
287        }
288
289        private static Type[] getTypes(Class<?>... src)
290        {
291            final Type[] result = new Type[src.length];
292            for (int i = 0; i < result.length; i++)
293            {
294                result[i] = Type.getType(src[i]);
295            }
296            return result;
297        }
298
299        /**
300         * Adapted from http://asm.ow2.org/doc/faq.html#Q5
301         * 
302         * @param b
303         * @return Class<?>
304         */
305        private static Class<?> loadClass(final ClassLoader loader, String className, byte[] b)
306        {
307            // override classDefine (as it is protected) and define the class.
308            try
309            {
310                final Method method = ClassLoader.class.getDeclaredMethod("defineClass", String.class, byte[].class,
311                        int.class, int.class);
312
313                // protected method invocation
314                final boolean accessible = method.isAccessible();
315                if (!accessible)
316                {
317                    method.setAccessible(true);
318                }
319                try
320                {
321                    return (Class<?>) method
322                            .invoke(loader, className, b, Integer.valueOf(0), Integer.valueOf(b.length));
323                }
324                finally
325                {
326                    if (!accessible)
327                    {
328                        method.setAccessible(false);
329                    }
330                }
331            }
332            catch (Exception e)
333            {
334                throw e instanceof RuntimeException ? ((RuntimeException) e) : new RuntimeException(e);
335            }
336        }
337    }
338
339    @SuppressWarnings("serial")
340    private static class DelegatorInvoker extends AbstractInvoker
341    {
342        private final ObjectProvider<?> delegateProvider;
343
344        protected DelegatorInvoker(ObjectProvider<?> delegateProvider)
345        {
346            this.delegateProvider = delegateProvider;
347        }
348
349        @Override
350        public Object invokeImpl(Object proxy, Method method, Object[] args) throws Throwable
351        {
352            try
353            {
354                return method.invoke(delegateProvider.getObject(), args);
355            }
356            catch (InvocationTargetException e)
357            {
358                throw e.getTargetException();
359            }
360        }
361    }
362
363    @SuppressWarnings("serial")
364    private static class InterceptorInvoker extends AbstractInvoker
365    {
366        private final Object target;
367        private final Interceptor methodInterceptor;
368
369        public InterceptorInvoker(Object target, Interceptor methodInterceptor)
370        {
371            this.target = target;
372            this.methodInterceptor = methodInterceptor;
373        }
374
375        @Override
376        public Object invokeImpl(Object proxy, Method method, Object[] args) throws Throwable
377        {
378            final ReflectionInvocation invocation = new ReflectionInvocation(target, proxy, method, args);
379            return methodInterceptor.intercept(invocation);
380        }
381    }
382
383    @SuppressWarnings("serial")
384    private abstract static class AbstractInvoker implements Invoker, Serializable
385    {
386        @Override
387        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable
388        {
389            if (isHashCode(method))
390            {
391                return Integer.valueOf(System.identityHashCode(proxy));
392            }
393            if (isEqualsMethod(method))
394            {
395                return Boolean.valueOf(proxy == args[0]);
396            }
397            return invokeImpl(proxy, method, args);
398        }
399
400        protected abstract Object invokeImpl(Object proxy, Method method, Object[] args) throws Throwable;
401    }
402
403    @SuppressWarnings("serial")
404    private static class InvokerInvoker extends AbstractInvoker
405    {
406        private final Invoker invoker;
407
408        public InvokerInvoker(Invoker invoker)
409        {
410            this.invoker = invoker;
411        }
412
413        @Override
414        public Object invokeImpl(Object proxy, Method method, Object[] args) throws Throwable
415        {
416            return invoker.invoke(proxy, method, args);
417        }
418    }
419
420    protected static boolean isHashCode(Method method)
421    {
422        return "hashCode".equals(method.getName()) && Integer.TYPE.equals(method.getReturnType())
423                && method.getParameterTypes().length == 0;
424    }
425
426    protected static boolean isEqualsMethod(Method method)
427    {
428        return "equals".equals(method.getName()) && Boolean.TYPE.equals(method.getReturnType())
429                && method.getParameterTypes().length == 1 && Object.class.equals(method.getParameterTypes()[0]);
430    }
431
432    private static class ReflectionInvocation implements Invocation
433    {
434        private final Method method;
435        private final Object[] arguments;
436        private final Object proxy;
437        private final Object target;
438
439        public ReflectionInvocation(final Object target, final Object proxy, final Method method,
440                final Object[] arguments)
441        {
442            this.method = method;
443            this.arguments = ObjectUtils.defaultIfNull(ArrayUtils.clone(arguments), ProxyUtils.EMPTY_ARGUMENTS);
444            this.proxy = proxy;
445            this.target = target;
446        }
447
448        @Override
449        public Object[] getArguments()
450        {
451            return arguments;
452        }
453
454        @Override
455        public Method getMethod()
456        {
457            return method;
458        }
459
460        @Override
461        public Object getProxy()
462        {
463            return proxy;
464        }
465
466        @Override
467        public Object proceed() throws Throwable
468        {
469            try
470            {
471                return method.invoke(target, arguments);
472            }
473            catch (InvocationTargetException e)
474            {
475                throw e.getTargetException();
476            }
477        }
478    }
479}