001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.commons.jcs3.io;
020
021import java.io.IOException;
022import java.io.InputStream;
023import java.io.ObjectInputStream;
024import java.io.ObjectStreamClass;
025import java.lang.reflect.Proxy;
026
/**
 * An {@link ObjectInputStream} that resolves classes through an explicitly supplied
 * {@link ClassLoader} and refuses to load class names rejected by a prefix-based
 * blacklist/whitelist filter.
 * <p>
 * The filter is configured once, from the system properties
 * {@code jcs.serialization.class.blacklist} and {@code jcs.serialization.class.whitelist},
 * each a comma-separated list of class-name prefixes.
 * </p>
 */
public class ObjectInputStreamClassLoaderAware extends ObjectInputStream {
    /** Loader used for every class and proxy interface resolved from the stream. */
    private final ClassLoader classLoader;

    /**
     * Creates a stream reading from {@code in}.
     *
     * @param in the stream to read serialized data from
     * @param classLoader the loader used to resolve classes; when {@code null} the
     *        context class loader of the constructing thread is used instead
     * @throws IOException if the stream header cannot be read
     */
    public ObjectInputStreamClassLoaderAware(final InputStream in, final ClassLoader classLoader) throws IOException {
        super(in);
        this.classLoader = classLoader != null ? classLoader : Thread.currentThread().getContextClassLoader();
    }

    /**
     * Resolves the class named by {@code desc} after the name has passed the filter.
     *
     * @param desc the class descriptor read from the stream
     * @return the resolved class, loaded without initialization
     * @throws ClassNotFoundException if the class cannot be found
     * @throws SecurityException if the class name is rejected by the filter
     */
    @Override
    protected Class<?> resolveClass(final ObjectStreamClass desc) throws ClassNotFoundException {
        final String name = BlacklistClassResolver.DEFAULT.check(desc.getName());
        try {
            return Class.forName(name, false, classLoader);
        } catch (final ClassNotFoundException e) {
            // Class.forName cannot resolve primitive type names such as "int", which appear
            // in the stream when a primitive Class object itself was serialized.
            final Class<?> primitive = primitiveClass(name);
            if (primitive != null) {
                return primitive;
            }
            throw e;
        }
    }

    /**
     * Resolves a dynamic proxy class for the given interface names. Every interface name
     * is passed through the filter before it is loaded.
     *
     * @param interfaces the interface names read from the proxy class descriptor
     * @return the proxy class implementing the given interfaces
     * @throws IOException declared by the overridden method; not thrown by this implementation
     * @throws ClassNotFoundException if an interface cannot be found or the proxy class
     *         cannot be created from the interfaces
     * @throws SecurityException if an interface name is rejected by the filter
     */
    @Override
    protected Class<?> resolveProxyClass(final String[] interfaces) throws IOException, ClassNotFoundException {
        final Class<?>[] cinterfaces = new Class[interfaces.length];
        for (int i = 0; i < interfaces.length; i++) {
            // Interface names come from the stream, so filter them like any other class name.
            cinterfaces[i] = Class.forName(BlacklistClassResolver.DEFAULT.check(interfaces[i]), false, classLoader);
        }

        try {
            return Proxy.getProxyClass(classLoader, cinterfaces);
        } catch (final IllegalArgumentException e) {
            throw new ClassNotFoundException(
                "Unable to create proxy class for interfaces [" + String.join(", ", interfaces) + "]", e);
        }
    }

    /**
     * Maps a primitive type name (or {@code void}) to its class object.
     *
     * @param name the type name to look up
     * @return the matching primitive class, or {@code null} if {@code name} is not a primitive type name
     */
    private static Class<?> primitiveClass(final String name) {
        switch (name) {
            case "boolean":
                return boolean.class;
            case "byte":
                return byte.class;
            case "char":
                return char.class;
            case "short":
                return short.class;
            case "int":
                return int.class;
            case "long":
                return long.class;
            case "float":
                return float.class;
            case "double":
                return double.class;
            case "void":
                return void.class;
            default:
                return null;
        }
    }

    /**
     * Prefix-based class-name filter. A name is rejected when it starts with any blacklist
     * prefix, or when a whitelist is configured and the name starts with none of its prefixes.
     */
    private static class BlacklistClassResolver {
        /** Filter built from the system properties when this class is initialized. */
        private static final BlacklistClassResolver DEFAULT = new BlacklistClassResolver(
            toArray(System.getProperty(
                "jcs.serialization.class.blacklist",
                "org.codehaus.groovy.runtime.,org.apache.commons.collections.functors.,org.apache.xalan")),
            toArray(System.getProperty("jcs.serialization.class.whitelist")));

        /** Rejected prefixes; {@code null} when none are configured. */
        private final String[] blacklist;
        /** Allowed prefixes; {@code null} when no whitelist is configured. */
        private final String[] whitelist;

        /**
         * @param blacklist prefixes of class names to reject, may be {@code null}
         * @param whitelist prefixes of the only class names to accept, may be {@code null}
         */
        protected BlacklistClassResolver(final String[] blacklist, final String[] whitelist) {
            this.whitelist = whitelist;
            this.blacklist = blacklist;
        }

        /**
         * @param name the class name to test
         * @return {@code true} if the name must not be deserialized
         */
        protected boolean isBlacklisted(final String name) {
            return (whitelist != null && !contains(whitelist, name)) || contains(blacklist, name);
        }

        /**
         * Verifies that {@code name} may be deserialized.
         *
         * @param name the class name to verify
         * @return {@code name} unchanged, so the call can be used inline
         * @throws SecurityException if the name is rejected by the filter
         */
        public final String check(final String name) {
            if (isBlacklisted(name)) {
                throw new SecurityException(name + " is not whitelisted as deserialisable, prevented before loading.");
            }
            return name;
        }

        /**
         * Splits a comma-separated property value into trimmed, non-empty prefixes.
         * <p>
         * Blank entries are dropped because an empty prefix would match every class name.
         * </p>
         *
         * @param property the raw property value, may be {@code null}
         * @return the prefixes, or {@code null} if the property is unset or has no
         *         non-blank entry (treated as "not configured")
         */
        private static String[] toArray(final String property) {
            if (property == null) {
                return null;
            }
            final String[] raw = property.split(",");
            final String[] kept = new String[raw.length];
            int count = 0;
            for (final String entry : raw) {
                final String prefix = entry.trim();
                if (!prefix.isEmpty()) {
                    kept[count++] = prefix;
                }
            }
            if (count == 0) {
                return null;
            }
            final String[] prefixes = new String[count];
            System.arraycopy(kept, 0, prefixes, 0, count);
            return prefixes;
        }

        /**
         * @param list the prefixes to test against, may be {@code null}
         * @param name the class name to test
         * @return {@code true} if {@code name} starts with any prefix in {@code list}
         */
        private static boolean contains(final String[] list, final String name) {
            if (list != null) {
                for (final String prefix : list) {
                    if (name.startsWith(prefix)) {
                        return true;
                    }
                }
            }
            return false;
        }
    }
}