1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  package org.apache.commons.compress.utils;
21  
22  import java.io.File;
23  import java.io.IOException;
24  import java.nio.ByteBuffer;
25  import java.nio.channels.ClosedChannelException;
26  import java.nio.channels.NonWritableChannelException;
27  import java.nio.channels.SeekableByteChannel;
28  import java.nio.file.Files;
29  import java.nio.file.Path;
30  import java.nio.file.StandardOpenOption;
31  import java.util.ArrayList;
32  import java.util.Arrays;
33  import java.util.Collections;
34  import java.util.List;
35  import java.util.Objects;
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
/**
 * Read-only {@link SeekableByteChannel} that presents an ordered list of channels as one
 * contiguous channel: its content is the concatenation of the wrapped channels' content, in
 * list order.
 *
 * <p>Reading and positioning are {@code synchronized} on this instance. Writing and truncation
 * are not supported and throw {@link NonWritableChannelException}. Closing this channel closes
 * every wrapped channel.</p>
 */
public class MultiReadOnlySeekableByteChannel implements SeekableByteChannel {

    private static final Path[] EMPTY_PATH_ARRAY = {};

    /**
     * Concatenates the given files.
     *
     * @param files the files to concatenate, in order; must not be {@code null}
     * @return a channel over the concatenated content; if exactly one file is given, the
     *         channel opened for it is returned directly instead of a wrapper
     * @throws NullPointerException if {@code files} is {@code null}
     * @throws IOException if a file cannot be opened
     */
    public static SeekableByteChannel forFiles(final File... files) throws IOException {
        final List<Path> paths = new ArrayList<>();
        for (final File f : Objects.requireNonNull(files, "files")) {
            paths.add(f.toPath());
        }
        return forPaths(paths.toArray(EMPTY_PATH_ARRAY));
    }

    /**
     * Concatenates the given paths, opening each one for reading.
     *
     * <p>If any path cannot be opened, the channels already opened by this call are closed
     * before the failure is rethrown, so nothing is leaked. Failures while closing them are
     * attached to the rethrown exception as suppressed exceptions.</p>
     *
     * @param paths the paths to concatenate, in order; must not be {@code null}
     * @return a channel over the concatenated content; if exactly one path is given, the
     *         channel opened for it is returned directly instead of a wrapper
     * @throws NullPointerException if {@code paths} is {@code null}
     * @throws IOException if a path cannot be opened
     */
    public static SeekableByteChannel forPaths(final Path... paths) throws IOException {
        final List<SeekableByteChannel> channels = new ArrayList<>();
        try {
            for (final Path path : Objects.requireNonNull(paths, "paths")) {
                channels.add(Files.newByteChannel(path, StandardOpenOption.READ));
            }
        } catch (final IOException | RuntimeException e) {
            // Opening a later path failed: release the channels opened so far instead of leaking them.
            for (final SeekableByteChannel opened : channels) {
                try {
                    opened.close();
                } catch (final IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw e;
        }
        if (channels.size() == 1) {
            return channels.get(0);
        }
        return new MultiReadOnlySeekableByteChannel(channels);
    }

    /**
     * Concatenates the given channels.
     *
     * @param channels the channels to concatenate, in order; must not be {@code null}
     * @return a channel over the concatenated content; if exactly one channel is given, it is
     *         returned as-is instead of a wrapper
     * @throws NullPointerException if {@code channels} is {@code null}
     */
    public static SeekableByteChannel forSeekableByteChannels(final SeekableByteChannel... channels) {
        if (Objects.requireNonNull(channels, "channels").length == 1) {
            return channels[0];
        }
        return new MultiReadOnlySeekableByteChannel(Arrays.asList(channels));
    }

    /** The wrapped channels, in concatenation order (an unmodifiable snapshot). */
    private final List<SeekableByteChannel> channelList;

    /** Current position within the concatenated content; guarded by {@code this}. */
    private long globalPosition;

    /** Index into {@link #channelList} of the channel the next read starts from; guarded by {@code this}. */
    private int currentChannelIdx;

    /**
     * Wraps the given channels. The list is copied, so later changes to it have no effect.
     *
     * @param channels the channels to concatenate, in order; must not be {@code null}
     * @throws NullPointerException if {@code channels} is {@code null}
     */
    public MultiReadOnlySeekableByteChannel(final List<SeekableByteChannel> channels) {
        this.channelList = Collections.unmodifiableList(new ArrayList<>(Objects.requireNonNull(channels, "channels")));
    }

    /**
     * Closes every wrapped channel, attempting all of them even if some fail.
     *
     * @throws IOException if at least one wrapped channel failed to close; its cause is the
     *         first failure and any further failures are attached as suppressed exceptions
     */
    @Override
    public void close() throws IOException {
        IOException failure = null;
        for (final SeekableByteChannel ch : channelList) {
            try {
                ch.close();
            } catch (final IOException ex) {
                if (failure == null) {
                    failure = new IOException("failed to close wrapped channel", ex);
                } else {
                    // Keep later failures visible instead of silently discarding them.
                    failure.addSuppressed(ex);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Tells whether this channel is open.
     *
     * @return {@code true} only if every wrapped channel is open (vacuously {@code true} when
     *         there are no wrapped channels)
     */
    @Override
    public boolean isOpen() {
        return channelList.stream().allMatch(SeekableByteChannel::isOpen);
    }

    /**
     * Returns the current position within the concatenated content.
     *
     * <p>Synchronized because the position is a non-volatile {@code long} that is only written
     * under this instance's lock; an unsynchronized read could observe a stale or torn value.</p>
     *
     * @return the current position
     */
    @Override
    public synchronized long position() {
        return globalPosition;
    }

    /**
     * Sets the position within the concatenated content.
     *
     * <p>Every wrapped channel is repositioned: channels entirely before the target are placed
     * at their end, the channel containing the target is placed at the matching offset, and all
     * later channels are rewound to 0. A position past the total size leaves every wrapped
     * channel at its end.</p>
     *
     * @param newPosition the new position; must not be negative
     * @return this channel
     * @throws IllegalArgumentException if {@code newPosition} is negative
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if a wrapped channel cannot report its size or be repositioned
     */
    @Override
    public synchronized SeekableByteChannel position(final long newPosition) throws IOException {
        if (newPosition < 0) {
            throw new IllegalArgumentException("Negative position: " + newPosition);
        }
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        globalPosition = newPosition;
        // Remaining distance to the target; -1L marks "target channel already found".
        long pos = newPosition;
        for (int i = 0; i < channelList.size(); i++) {
            final SeekableByteChannel currentChannel = channelList.get(i);
            final long size = currentChannel.size();

            final long newChannelPos;
            if (pos == -1L) {
                // The target lies in an earlier channel: rewind this one.
                newChannelPos = 0;
            } else if (pos <= size) {
                // The target lies in this channel.
                currentChannelIdx = i;
                final long tmp = pos;
                pos = -1L;
                newChannelPos = tmp;
            } else {
                // The target lies beyond this channel: park it at its end and keep going.
                pos -= size;
                newChannelPos = size;
            }
            currentChannel.position(newChannelPos);
        }
        return this;
    }

    /**
     * Sets the position to {@code relativeOffset} bytes past the start of the wrapped channel
     * with (zero-based) index {@code channelNumber}.
     *
     * @param channelNumber index of the wrapped channel the offset is relative to
     * @param relativeOffset offset from the start of that channel
     * @return this channel
     * @throws ClosedChannelException if this channel is closed
     * @throws IndexOutOfBoundsException if {@code channelNumber} is greater than the number of
     *         wrapped channels
     * @throws IllegalArgumentException if the resulting position is negative
     * @throws IOException if a wrapped channel cannot report its size or be repositioned
     */
    public synchronized SeekableByteChannel position(final long channelNumber, final long relativeOffset) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        // Translate to a global position: sizes of all preceding channels plus the offset.
        long target = relativeOffset;
        for (int i = 0; i < channelNumber; i++) {
            target += channelList.get(i).size();
        }
        return position(target);
    }

    /**
     * Reads bytes into {@code dst}, continuing across channel boundaries until {@code dst} is
     * full or every remaining wrapped channel is exhausted.
     *
     * @param dst the buffer to fill
     * @return the number of bytes read, {@code 0} if {@code dst} has no space remaining, or
     *         {@code -1} if no bytes could be read
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if a wrapped channel fails
     */
    @Override
    public synchronized int read(final ByteBuffer dst) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        if (!dst.hasRemaining()) {
            return 0;
        }

        int totalBytesRead = 0;
        while (dst.hasRemaining() && currentChannelIdx < channelList.size()) {
            final SeekableByteChannel currentChannel = channelList.get(currentChannelIdx);
            final int newBytesRead = currentChannel.read(dst);
            if (newBytesRead == -1) {
                // This channel is exhausted: move on to the next one.
                currentChannelIdx++;
                continue;
            }
            if (currentChannel.position() >= currentChannel.size()) {
                // This read reached the end of the channel: the next read starts at the next one.
                currentChannelIdx++;
            }
            totalBytesRead += newBytesRead;
        }
        if (totalBytesRead > 0) {
            globalPosition += totalBytesRead;
            return totalBytesRead;
        }
        return -1;
    }

    /**
     * Returns the total size of the concatenated content.
     *
     * @return the sum of the sizes of all wrapped channels
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException if a wrapped channel cannot report its size
     */
    @Override
    public long size() throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        long acc = 0;
        for (final SeekableByteChannel ch : channelList) {
            acc += ch.size();
        }
        return acc;
    }

    /**
     * Not supported; this channel is read-only.
     *
     * @throws NonWritableChannelException always
     */
    @Override
    public SeekableByteChannel truncate(final long size) {
        throw new NonWritableChannelException();
    }

    /**
     * Not supported; this channel is read-only.
     *
     * @throws NonWritableChannelException always
     */
    @Override
    public int write(final ByteBuffer src) {
        throw new NonWritableChannelException();
    }

}