001    /*
002     *  Copyright the original author or authors.
003     *
004     *  Licensed under the Apache License, Version 2.0 (the "License");
005     *  you may not use this file except in compliance with the License.
006     *  You may obtain a copy of the License at
007     *
008     *      http://www.apache.org/licenses/LICENSE-2.0
009     *
010     *  Unless required by applicable law or agreed to in writing, software
011     *  distributed under the License is distributed on an "AS IS" BASIS,
012     *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     *  See the License for the specific language governing permissions and
014     *  limitations under the License.
015     */
016    package org.apache.commons.privilizer.weave;
017    
018    import java.io.IOException;
019    import java.lang.reflect.Modifier;
020    import java.nio.charset.Charset;
021    import java.security.AccessController;
022    import java.security.PrivilegedAction;
023    import java.security.PrivilegedActionException;
024    import java.security.PrivilegedExceptionAction;
025    import java.util.ArrayList;
026    import java.util.Comparator;
027    import java.util.List;
028    import java.util.Set;
029    import java.util.TreeSet;
030    import java.util.logging.Logger;
031    
032    import javassist.CannotCompileException;
033    import javassist.ClassPool;
034    import javassist.CtClass;
035    import javassist.CtField;
036    import javassist.CtMethod;
037    import javassist.CtNewConstructor;
038    import javassist.CtNewMethod;
039    import javassist.CtPrimitiveType;
040    import javassist.NotFoundException;
041    import javassist.bytecode.Descriptor;
042    
043    import org.apache.commons.lang3.ObjectUtils;
044    import org.apache.commons.lang3.StringUtils;
045    import org.apache.commons.lang3.Validate;
046    import org.apache.commons.lang3.text.StrBuilder;
047    import org.apache.commons.privilizer.Privileged;
048    
049    
050    /**
051     * Handles weaving of methods annotated with {@link Privileged}.
052     */
053    public abstract class Privilizer<SELF extends Privilizer<SELF>> {
    /**
     * Callback through which classes are written out. It is invoked for a woven
     * class once weaving produced changes, and for every generated action class.
     */
    public interface ClassFileWriter {
        /**
         * Writes the specified class.
         *
         * @param type the class to write
         * @throws CannotCompileException may be thrown by implementations
         * @throws IOException may be thrown by implementations
         */
        void write(CtClass type) throws CannotCompileException, IOException;
    }
057    
    /**
     * Logging abstraction used by the weaver to report its progress. An
     * implementation is installed with {@link Privilizer#loggingTo(Log)}.
     */
    public interface Log {
        /**
         * Logs a debug-level message.
         *
         * @param message the already formatted message
         */
        void debug(String message);

        /**
         * Logs a verbose-level message.
         *
         * @param message the already formatted message
         */
        void verbose(String message);

        /**
         * Logs an error-level message.
         *
         * @param message the already formatted message
         */
        void error(String message);

        /**
         * Logs an info-level message.
         *
         * @param message the already formatted message
         */
        void info(String message);

        /**
         * Logs a warning-level message.
         *
         * @param message the already formatted message
         */
        void warn(String message);
    }
069    
070        /**
071         * Weaving policy: when to use {@link PrivilegedAction}s.
072         */
073        public enum Policy {
074            /**
075             * Disables weaving.
076             */
077            NEVER,
078    
079            /**
080             * Weaves such that the check for an active {@link SecurityManager} is
081             * done once only.
082             */
083            ON_INIT(generateName("hasSecurityManager")),
084    
085            /**
086             * Weaves such that the check for an active {@link SecurityManager} is
087             * done for each {@link Privileged} method execution.
088             */
089            DYNAMIC(HAS_SECURITY_MANAGER_CONDITION),
090    
091            /**
092             * Weaves such that {@link Privileged} methods are always executed as
093             * such.
094             */
095            ALWAYS;
096    
097            private final String condition;
098    
099            private Policy() {
100                this(null);
101            }
102    
103            private Policy(String condition) {
104                this.condition = condition;
105            }
106    
107            private boolean isConditional() {
108                return condition != null;
109            }
110        }
111    
    /**
     * Simple name which, once passed through {@link #generateName(String)}, names
     * the class attribute that records the policy a class was woven with.
     */
    protected static final String POLICY_NAME = "policyName";

    /** Suffix appended to the names of generated action classes. */
    private static final String ACTION_SUFFIX = "_ACTION";

    /** Format applied by {@link #generateName(String)} to build generated member names. */
    private static final String GENERATE_NAME = "__privileged_%s";
    /** Source expression that is true when a {@link SecurityManager} is installed. */
    private static final String HAS_SECURITY_MANAGER_CONDITION = "System.getSecurityManager() != null";
118    
119        protected static String generateName(String simple) {
120            return String.format(GENERATE_NAME, simple);
121        }
122    
123        protected static String toString(byte[] b) {
124            return b == null ? null : new String(b, Charset.forName("UTF-8"));
125        }
126    
    /** Weaving policy in effect; the constructor rejects {@code null}. */
    protected final Policy policy;

    /** Javassist class pool used to resolve the action interface types; the constructor rejects {@code null}. */
    protected final ClassPool classPool;

    /** Whether the weave policy has already been reported to the current {@link Log}. */
    private boolean settingsReported;
132    
133        private Log log = new Log() {
134            final Logger logger = Logger.getLogger(Privilizer.class.getName());
135    
136            @Override
137            public void debug(String message) {
138                logger.finer(message);
139            }
140    
141            @Override
142            public void verbose(String message) {
143                logger.fine(message);
144            }
145    
146            @Override
147            public void error(String message) {
148                logger.severe(message);
149            }
150    
151            @Override
152            public void info(String message) {
153                logger.info(message);
154            }
155    
156            @Override
157            public void warn(String message) {
158                logger.warning(message);
159            }
160    
161        };
162    
163        private static final Comparator<CtMethod> CTMETHOD_COMPARATOR = new Comparator<CtMethod>() {
164    
165            @Override
166            public int compare(CtMethod arg0, CtMethod arg1) {
167                if (ObjectUtils.equals(arg0, arg1)) {
168                    return 0;
169                }
170                if (arg0 == null) {
171                    return -1;
172                }
173                if (arg1 == null) {
174                    return 1;
175                }
176                final int result = ObjectUtils.compare(arg0.getName(), arg1.getName());
177                return result == 0 ? ObjectUtils.compare(arg0.getSignature(), arg1.getSignature()) : result;
178            }
179        };
180    
181        private static Set<CtMethod> getPrivilegedMethods(CtClass type) throws ClassNotFoundException {
182            final TreeSet<CtMethod> result = new TreeSet<CtMethod>(CTMETHOD_COMPARATOR);
183            for (final CtMethod m : type.getDeclaredMethods()) {
184                if (Modifier.isAbstract(m.getModifiers()) || m.getAnnotation(Privileged.class) == null) {
185                    continue;
186                }
187                result.add(m);
188            }
189            return result;
190        }
191    
192        public Privilizer(ClassPool classPool) {
193            this(Policy.DYNAMIC, classPool);
194        }
195    
196        public Privilizer(Policy policy, ClassPool classPool) {
197            this.policy = Validate.notNull(policy, "policy");
198            this.classPool = Validate.notNull(classPool, "classPool");
199        }
200    
201        public SELF loggingTo(Log log) {
202            this.log = Validate.notNull(log);
203            settingsReported = false;
204            @SuppressWarnings("unchecked")
205            final SELF self = (SELF) this;
206            return self;
207        }
208    
209        /**
210         * Weave the specified class.
211         * 
212         * @param type
213         * @return whether any work was done
214         * @throws NotFoundException
215         * @throws IOException
216         * @throws CannotCompileException
217         * @throws ClassNotFoundException
218         */
219        public boolean weave(CtClass type) throws NotFoundException, IOException, CannotCompileException,
220            ClassNotFoundException {
221            reportSettings();
222            final String policyName = generateName(POLICY_NAME);
223            final String policyValue = toString(type.getAttribute(policyName));
224            if (policyValue != null) {
225                verbose("%s already woven with policy %s", type.getName(), policyValue);
226                if (!policy.name().equals(policyValue)) {
227                    throw new AlreadyWovenException(type.getName(), Policy.valueOf(policyValue));
228                }
229                return false;
230            }
231            boolean result = false;
232            if (policy.compareTo(Policy.NEVER) > 0) {
233                if (policy == Policy.ON_INIT) {
234                    debug("Initializing field %s to %s", policy.condition, HAS_SECURITY_MANAGER_CONDITION);
235                    type.addField(new CtField(CtClass.booleanType, policy.condition, type),
236                        CtField.Initializer.byExpr(HAS_SECURITY_MANAGER_CONDITION));
237                }
238                for (final CtMethod m : getPrivilegedMethods(type)) {
239                    result |= weave(type, m);
240                }
241                if (result) {
242                    type.setAttribute(policyName, policy.name().getBytes(Charset.forName("UTF-8")));
243                    getClassFileWriter().write(type);
244                }
245            }
246            log.verbose(String.format(result ? "Wove class %s" : "Nothing to do for class %s", type.getName()));
247            return result;
248        }
249    
250        protected void debug(String message, Object... args) {
251            log.debug(String.format(message, args));
252        }
253    
254        protected void verbose(String message, Object... args) {
255            log.verbose(String.format(message, args));
256        }
257    
258        protected void warn(String message, Object... args) {
259            log.warn(String.format(message, args));
260        }
261    
    /**
     * Supplies the writer used to write woven classes and generated action
     * classes. The result is dereferenced immediately by the callers in this
     * class, so it must not be {@code null}.
     *
     * @return the class file writer
     */
    protected abstract ClassFileWriter getClassFileWriter();
263    
264        protected void info(String message, Object... args) {
265            log.info(String.format(message, args));
266        }
267    
    /**
     * Extension hook deciding whether methods of the given access level may be
     * woven. This implementation permits every access level. When {@code false}
     * is returned, the method in question is skipped with a warning.
     *
     * @param accessLevel the access level of the method about to be woven
     * @return {@code true} to weave the method, {@code false} to skip it
     */
    protected boolean permitMethodWeaving(AccessLevel accessLevel) {
        return true;
    }
271    
272        private CtClass createAction(CtClass type, CtMethod impl, Class<?> iface) throws NotFoundException,
273            CannotCompileException, IOException {
274            final boolean exc = impl.getExceptionTypes().length > 0;
275    
276            final CtClass actionType = classPool.get(iface.getName());
277    
278            final String simpleName = generateActionClassname(impl);
279            debug("Creating action type %s for method %s", simpleName, toString(impl));
280            final CtClass result = type.makeNestedClass(simpleName, true);
281            result.addInterface(actionType);
282    
283            final CtField owner;
284            if (Modifier.isStatic(impl.getModifiers())) {
285                owner = null;
286            } else {
287                owner = new CtField(type, generateName("owner"), result);
288                owner.setModifiers(Modifier.PRIVATE | Modifier.FINAL);
289                debug("Adding owner field %s to %s", owner.getName(), simpleName);
290                result.addField(owner);
291            }
292    
293            final List<String> propagatedParameters = new ArrayList<String>();
294            int index = -1;
295            for (final CtClass param : impl.getParameterTypes()) {
296                final String f = String.format("arg%s", Integer.valueOf(++index));
297                final CtField fld = new CtField(param, f, result);
298                fld.setModifiers(Modifier.PRIVATE | Modifier.FINAL);
299                debug("Copying parameter %s from %s to %s.%s", index, toString(impl), simpleName, f);
300                result.addField(fld);
301                propagatedParameters.add(f);
302            }
303            {
304                final StrBuilder constructor = new StrBuilder(simpleName).append('(');
305                boolean sep = false;
306                final Body body = new Body();
307    
308                for (final CtField fld : result.getDeclaredFields()) {
309                    if (sep) {
310                        constructor.append(", ");
311                    } else {
312                        sep = true;
313                    }
314                    constructor.append(fld.getType().getName()).append(' ').append(fld.getName());
315                    body.appendLine("this.%1$s = %1$s;", fld.getName());
316                }
317                constructor.append(") ").append(body.complete());
318    
319                final String c = constructor.toString();
320                debug("Creating action constructor:");
321                debug(c);
322                result.addConstructor(CtNewConstructor.make(c, result));
323            }
324            {
325                final StrBuilder run = new StrBuilder("public Object run() ");
326                if (exc) {
327                    run.append("throws Exception ");
328                }
329                final Body body = new Body();
330                final CtClass rt = impl.getReturnType();
331                final boolean isVoid = rt.equals(CtClass.voidType);
332                if (!isVoid) {
333                    body.append("return ");
334                }
335                final String deref = Modifier.isStatic(impl.getModifiers()) ? type.getName() : owner.getName();
336                final String call =
337                    String.format("%s.%s(%s)", deref, impl.getName(), StringUtils.join(propagatedParameters, ", "));
338    
339                if (!isVoid && rt.isPrimitive()) {
340                    body.appendLine("%2$s.valueOf(%1$s);", call, ((CtPrimitiveType) rt).getWrapperName());
341                } else {
342                    body.append(call).append(';').appendNewLine();
343    
344                    if (isVoid) {
345                        body.appendLine("return null;");
346                    }
347                }
348    
349                run.append(body.complete());
350    
351                final String r = run.toString();
352                debug("Creating run method:");
353                debug(r);
354                result.addMethod(CtNewMethod.make(r, result));
355            }
356            getClassFileWriter().write(result);
357            debug("Returning action type %s", result);
358            return result;
359        }
360    
361        private String generateActionClassname(CtMethod m) throws NotFoundException {
362            final StringBuilder b = new StringBuilder(m.getName());
363            if (m.getParameterTypes().length > 0) {
364                b.append("$$").append(
365                    StringUtils.strip(Descriptor.getParamDescriptor(m.getSignature()), "(;)").replace("[", "ARRAYOF_")
366                        .replace('/', '_').replace(';', '$'));
367            }
368            return b.append(ACTION_SUFFIX).toString();
369        }
370    
371        private String toString(CtMethod m) {
372            return String.format("%s%s", m.getName(), m.getSignature());
373        }
374    
    /**
     * Weaves a single {@link Privileged} method. The original method is copied to
     * a private method named via {@link #generateName(String)}, an action class
     * invoking that copy is generated, and the body of {@code method} is replaced
     * with generated source that runs the action through
     * {@code AccessController.doPrivileged}.
     * <p>
     * For conditional policies the privileged call is wrapped in an
     * {@code if (condition)} block, followed by a direct call to the copy. If the
     * method declares exceptions, a {@link PrivilegedExceptionAction} is used and
     * the generated code unwraps the {@link PrivilegedActionException}: the cause
     * is rethrown if it is an instance of a declared exception type or a
     * {@link RuntimeException}, and wrapped in a {@link RuntimeException}
     * otherwise.
     *
     * @param type the class declaring {@code method}
     * @param method the method to weave
     * @return {@code false} if {@link #permitMethodWeaving(AccessLevel)} rejected
     *         the method, {@code true} once its body has been replaced
     * @throws ClassNotFoundException declared for the weaving operations
     * @throws CannotCompileException if generated code cannot be compiled
     * @throws NotFoundException if thrown by Javassist while resolving types
     * @throws IOException if writing the generated action class fails
     */
    private boolean weave(CtClass type, CtMethod method) throws ClassNotFoundException, CannotCompileException,
        NotFoundException, IOException {
        final AccessLevel accessLevel = AccessLevel.of(method.getModifiers());
        // subclasses may veto weaving based on the method's access level
        if (!permitMethodWeaving(accessLevel)) {
            warn("Ignoring %s method %s.%s", accessLevel, type.getName(), toString(method));
            return false;
        }
        // NOTE(review): relies on the declaration order of AccessLevel, which is not
        // visible here - presumably true for levels more accessible than PACKAGE; confirm
        if (AccessLevel.PACKAGE.compareTo(accessLevel) > 0) {
            warn("Possible security leak: granting privileges to %s method %s.%s", accessLevel, type.getName(),
                toString(method));
        }
        final String implName = generateName(method.getName());

        // keep the original implementation under a generated name
        final CtMethod impl = CtNewMethod.copy(method, implName, type, null);
        // presumably replaces the access bits with private while keeping the other
        // modifiers - verify against AccessLevel.merge
        impl.setModifiers(AccessLevel.PRIVATE.merge(method.getModifiers()));
        type.addMethod(impl);
        debug("Copied %2$s %1$s.%3$s to %4$s %1$s.%5$s", type.getName(), accessLevel, toString(method),
            AccessLevel.PRIVATE, toString(impl));

        // from here on the replacement body of the public-facing method is assembled
        final Body body = new Body();
        if (policy.isConditional()) {
            body.startBlock("if (%s)", policy.condition);
        }

        final boolean exc = method.getExceptionTypes().length > 0;

        if (exc) {
            body.startBlock("try");
        }

        // methods declaring exceptions need the exception-capable action interface
        final Class<?> iface = exc ? PrivilegedExceptionAction.class : PrivilegedAction.class;
        final CtClass actionType = createAction(type, impl, iface);
        final String action = generateName("action");

        // instantiate the action: owner first (instance methods only), then every
        // parameter; $0 and $1..$n are Javassist's names for this and the parameters
        body.append("final %s %s = new %s(", iface.getName(), action, actionType.getName());
        boolean firstParam;
        if (Modifier.isStatic(impl.getModifiers())) {
            firstParam = true;
        } else {
            body.append("$0");
            firstParam = false;
        }
        for (int i = 1, sz = impl.getParameterTypes().length; i <= sz; i++) {
            if (firstParam) {
                firstParam = false;
            } else {
                body.append(", ");
            }
            body.append('$').append(Integer.toString(i));
        }
        body.appendLine(");");

        final CtClass rt = method.getReturnType();
        final boolean isVoid = rt.equals(CtClass.voidType);

        final String doPrivileged = String.format("%1$s.doPrivileged(%2$s)", AccessController.class.getName(), action);
        if (isVoid) {
            body.append(doPrivileged).append(';').appendNewLine();
            // with a conditional policy a direct call to the copy follows the if
            // block, so the privileged branch has to return explicitly
            if (policy.isConditional()) {
                body.appendLine("return;");
            }
        } else {
            // doPrivileged returns Object: cast to the (wrapper of the) return type
            final String cast = rt.isPrimitive() ? ((CtPrimitiveType) rt).getWrapperName() : rt.getName();
            // don't worry about wrapper NPEs because we should be simply
            // passing back an autoboxed value, then unboxing again
            final String result = generateName("result");
            body.appendLine("final %1$s %3$s = (%1$s) %2$s;", cast, doPrivileged, result);
            body.append("return %s", result);
            if (rt.isPrimitive()) {
                // unbox explicitly, e.g. .intValue()
                body.append(".%sValue()", rt.getName());
            }
            body.append(';').appendNewLine();
        }

        if (exc) {
            // close the try block and unwrap the PrivilegedActionException
            body.endBlock();
            final String e = generateName("e");
            body.startBlock("catch (%1$s %2$s)", PrivilegedActionException.class.getName(), e).appendNewLine();

            final String wrapped = generateName("wrapped");

            body.appendLine("final Exception %1$s = %2$s.getCause();", wrapped, e);
            // rethrow the cause as whichever declared exception type it matches
            for (final CtClass thrown : method.getExceptionTypes()) {
                body.startBlock("if (%1$s instanceof %2$s)", wrapped, thrown.getName());
                body.appendLine("throw (%2$s) %1$s;", wrapped, thrown.getName());
                body.endBlock();
            }
            // anything else can only be propagated unchecked
            body.appendLine(
                "throw %1$s instanceof RuntimeException ? (RuntimeException) %1$s : new RuntimeException(%1$s);",
                wrapped);
            body.endBlock();
        }

        if (policy.isConditional()) {
            // close if block we opened before:
            body.endBlock();
            // no security manager=> just call impl:
            if (!isVoid) {
                body.append("return ");
            }
            // $$ is Javassist's shorthand for the full list of actual parameters
            body.appendLine("%s($$);", impl.getName());
        }

        final String block = body.complete().toString();
        debug("Setting body of %s to:\n%s", toString(method), block);
        method.setBody(block);
        return true;
    }
483    
484        private void reportSettings() {
485            if (!settingsReported) {
486                settingsReported = true;
487                info("Weave policy == %s", policy);
488            }
489        }
490    }