View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.io.channels;
19  
20  import java.lang.reflect.InvocationHandler;
21  import java.lang.reflect.InvocationTargetException;
22  import java.lang.reflect.Method;
23  import java.lang.reflect.Proxy;
24  import java.nio.channels.AsynchronousChannel;
25  import java.nio.channels.ByteChannel;
26  import java.nio.channels.Channel;
27  import java.nio.channels.ClosedChannelException;
28  import java.nio.channels.GatheringByteChannel;
29  import java.nio.channels.InterruptibleChannel;
30  import java.nio.channels.NetworkChannel;
31  import java.nio.channels.ReadableByteChannel;
32  import java.nio.channels.ScatteringByteChannel;
33  import java.nio.channels.SeekableByteChannel;
34  import java.nio.channels.WritableByteChannel;
35  import java.util.Collections;
36  import java.util.HashSet;
37  import java.util.Objects;
38  import java.util.Set;
39  
/**
 * {@link InvocationHandler} behind a close-shield {@link Proxy} for a {@link Channel}.
 * <p>
 * Closing the proxy only sets a flag local to this handler; the wrapped delegate is never closed here. Once the
 * shield is closed, {@code isOpen()} reports {@code false} and every other channel call fails with
 * {@link ClosedChannelException}, except {@link NetworkChannel#supportedOptions()}. Calls whose JDK contract is to
 * return the channel itself are answered with the proxy, so callers are not handed the unshielded delegate.
 * </p>
 */
final class CloseShieldChannelHandler implements InvocationHandler {

    /** Channel interfaces a shield proxy may implement; queried through {@link #isSupported(Class)}. */
    private static final Set<Class<? extends Channel>> SUPPORTED_INTERFACES = newSupportedInterfaces();

    /** The channel all permitted calls are forwarded to; never closed by this handler. */
    private final Channel delegate;

    /** Set once {@code close()} is called on the proxy; volatile so the flag is visible across threads. */
    private volatile boolean closed;

    /**
     * Creates a handler that forwards to the given channel.
     *
     * @param delegate the channel to shield, must not be {@code null}.
     * @throws NullPointerException if {@code delegate} is {@code null}.
     */
    CloseShieldChannelHandler(final Channel delegate) {
        this.delegate = Objects.requireNonNull(delegate, "delegate");
    }

    /**
     * Builds the immutable set of channel interfaces a shield proxy may expose.
     *
     * @return an unmodifiable set of the supported channel interfaces.
     */
    private static Set<Class<? extends Channel>> newSupportedInterfaces() {
        final Set<Class<? extends Channel>> types = new HashSet<>();
        types.add(AsynchronousChannel.class);
        types.add(ByteChannel.class);
        types.add(Channel.class);
        types.add(GatheringByteChannel.class);
        types.add(InterruptibleChannel.class);
        types.add(NetworkChannel.class);
        types.add(ReadableByteChannel.class);
        types.add(ScatteringByteChannel.class);
        types.add(SeekableByteChannel.class);
        types.add(WritableByteChannel.class);
        return Collections.unmodifiableSet(types);
    }

    /**
     * Tests whether a shield proxy may expose the given interface.
     *
     * @param interfaceClass the interface to check.
     * @return {@code true} if the interface is one of the supported channel interfaces, {@code false} otherwise.
     */
    static boolean isSupported(final Class<?> interfaceClass) {
        return SUPPORTED_INTERFACES.contains(interfaceClass);
    }

    /**
     * Tests whether a call may still be forwarded to the delegate after the shield has been closed.
     * <p>
     * Only {@code NetworkChannel.supportedOptions()} qualifies, because the JDK allows that query on a closed channel.
     * </p>
     *
     * @param owner      the interface declaring the invoked method.
     * @param methodName the invoked method's name.
     * @param arity      the invoked method's parameter count.
     * @return {@code true} if the call is allowed after {@code close()}, {@code false} otherwise.
     */
    private static boolean permittedWhenClosed(final Class<?> owner, final String methodName, final int arity) {
        if (owner != NetworkChannel.class || arity != 0) {
            return false;
        }
        return "supportedOptions".equals(methodName);
    }

    /**
     * Tests whether the invoked method is specified by the JDK to return the channel itself.
     * <p>
     * The methods covered are {@code SeekableByteChannel.position(long)}, {@code SeekableByteChannel.truncate(long)},
     * {@code NetworkChannel.bind(SocketAddress)} and {@code NetworkChannel.setOption(SocketOption, Object)}. For these,
     * the proxy is returned instead of the delegate's result.
     * </p>
     *
     * @param owner      the interface declaring the invoked method.
     * @param methodName the invoked method's name.
     * @param arity      the invoked method's parameter count.
     * @return {@code true} if the proxy should be returned in place of the delegate's result.
     */
    private static boolean yieldsSelf(final Class<?> owner, final String methodName, final int arity) {
        switch (methodName) {
        case "position":
        case "truncate":
            return owner == SeekableByteChannel.class && arity == 1;
        case "bind":
            return owner == NetworkChannel.class && arity == 1;
        case "setOption":
            return owner == NetworkChannel.class && arity == 2;
        default:
            return false;
        }
    }

    /**
     * Dispatches a call made on the shield proxy.
     *
     * @param proxy  the shield proxy the call was made on.
     * @param method the invoked interface method.
     * @param args   the call arguments, possibly {@code null} for no-argument methods.
     * @return the delegate's result, the proxy itself for methods returning the channel, or the shield's own answer
     *         for {@code close()}, {@code isOpen()} and the {@code java.lang.Object} methods.
     * @throws ClosedChannelException if the shield is closed and the method is not allowed after close.
     * @throws Throwable              whatever the delegate throws, unwrapped from the reflective wrapper.
     */
    @Override
    public Object invoke(final Object proxy, final Method method, final Object[] args) throws Throwable {
        final Class<?> owner = method.getDeclaringClass();
        final String methodName = method.getName();
        if (owner == Object.class) {
            // Proxy only routes equals, hashCode and toString from java.lang.Object to the handler.
            return handleObjectMethod(proxy, methodName, args);
        }
        final int arity = method.getParameterCount();
        if (arity == 0) {
            if ("close".equals(methodName)) {
                // Shield-only close: the delegate is deliberately left open.
                closed = true;
                return null;
            }
            if ("isOpen".equals(methodName)) {
                return !closed && delegate.isOpen();
            }
        }
        if (closed && !permittedWhenClosed(owner, methodName, arity)) {
            throw new ClosedChannelException();
        }
        final Object outcome;
        try {
            outcome = method.invoke(delegate, args);
        } catch (final InvocationTargetException e) {
            // Surface the delegate's own exception instead of the reflection wrapper.
            throw e.getCause();
        }
        return yieldsSelf(owner, methodName, arity) ? proxy : outcome;
    }

    /**
     * Answers the {@code java.lang.Object} methods on behalf of the proxy.
     *
     * @param proxy      the shield proxy.
     * @param methodName one of {@code toString}, {@code hashCode} or {@code equals}.
     * @param args       the call arguments; {@code args[0]} is the comparand for {@code equals}.
     * @return the method's result, or {@code null} for any other name (not expected to occur).
     */
    private Object handleObjectMethod(final Object proxy, final String methodName, final Object[] args) {
        if ("toString".equals(methodName)) {
            return "CloseShieldChannel(" + delegate + ")";
        }
        if ("hashCode".equals(methodName)) {
            return Objects.hashCode(delegate);
        }
        if ("equals".equals(methodName)) {
            return proxyEquals(proxy, args[0]);
        }
        // Unreachable: Proxy dispatches no other java.lang.Object methods.
        return null;
    }

    /**
     * Implements {@code equals} for the proxy. Two shields are equal when they wrap equal delegates.
     *
     * @param proxy the shield proxy {@code equals} was called on.
     * @param other the object to compare against.
     * @return {@code true} if {@code other} is this proxy, or another close-shield proxy over an equal delegate.
     */
    private boolean proxyEquals(final Object proxy, final Object other) {
        if (other == proxy) {
            return true;
        }
        if (other == null || !Proxy.isProxyClass(other.getClass())) {
            return false;
        }
        final InvocationHandler otherHandler = Proxy.getInvocationHandler(other);
        return otherHandler instanceof CloseShieldChannelHandler
                && Objects.equals(((CloseShieldChannelHandler) otherHandler).delegate, delegate);
    }
}