001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   https://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019
020package org.apache.commons.compress.utils;
021
022import java.io.File;
023import java.io.IOException;
024import java.nio.ByteBuffer;
025import java.nio.channels.ClosedChannelException;
026import java.nio.channels.NonWritableChannelException;
027import java.nio.channels.SeekableByteChannel;
028import java.nio.file.Files;
029import java.nio.file.Path;
030import java.nio.file.StandardOpenOption;
031import java.util.ArrayList;
032import java.util.Arrays;
033import java.util.Collections;
034import java.util.List;
035import java.util.Objects;
036
/**
 * Implements a read-only {@link SeekableByteChannel} that concatenates a collection of other {@link SeekableByteChannel}s.
 * <p>
 * This is a loose port of <a href=
 * "https://github.com/frugalmechanic/fm-common/blob/master/jvm/src/main/scala/fm/common/MultiReadOnlySeekableByteChannel.scala">
 * MultiReadOnlySeekableByteChannel</a>
 * by Tim Underwood.
 * </p>
 * <p>
 * Reading and repositioning are {@code synchronized} on this instance. Closing this channel closes every wrapped channel.
 * </p>
 *
 * @since 1.19
 */
public class MultiReadOnlySeekableByteChannel implements SeekableByteChannel {

    private static final Path[] EMPTY_PATH_ARRAY = {};

    /**
     * Concatenates the given files.
     *
     * @param files the files to concatenate, in the order their content should appear
     * @throws NullPointerException if files is null
     * @throws IOException          if opening a channel for one of the files fails
     * @return SeekableByteChannel that concatenates all provided files; if exactly one file is given, its own channel is returned unwrapped
     */
    public static SeekableByteChannel forFiles(final File... files) throws IOException {
        final List<Path> paths = new ArrayList<>();
        for (final File f : Objects.requireNonNull(files, "files")) {
            paths.add(f.toPath());
        }
        return forPaths(paths.toArray(EMPTY_PATH_ARRAY));
    }

    /**
     * Concatenates the given file paths.
     * <p>
     * If opening any of the paths fails, every channel that was already opened by this call is closed before the exception propagates.
     * </p>
     *
     * @param paths the file paths to concatenate, in order. When used for split ZIP archives, the LAST FILE should be the LAST SEGMENT (.zip) and the
     *              files should be given in the correct order (for example: .z01, .z02... .z99, .zip)
     * @return SeekableByteChannel that concatenates all provided files; if exactly one path is given, its own channel is returned unwrapped
     * @throws NullPointerException if paths is null
     * @throws IOException          if opening a channel for one of the files fails
     * @since 1.22
     */
    public static SeekableByteChannel forPaths(final Path... paths) throws IOException {
        Objects.requireNonNull(paths, "paths");
        final List<SeekableByteChannel> channels = new ArrayList<>(paths.length);
        try {
            for (final Path path : paths) {
                channels.add(Files.newByteChannel(path, StandardOpenOption.READ));
            }
        } catch (final IOException | RuntimeException e) {
            // Don't leak the channels opened before the failure; keep the original exception as the primary one.
            for (final SeekableByteChannel opened : channels) {
                try {
                    opened.close();
                } catch (final IOException suppressed) {
                    e.addSuppressed(suppressed);
                }
            }
            throw e;
        }
        if (channels.size() == 1) {
            return channels.get(0);
        }
        return new MultiReadOnlySeekableByteChannel(channels);
    }

    /**
     * Concatenates the given channels.
     *
     * @param channels the channels to concatenate, in order
     * @throws NullPointerException if channels is null
     * @return SeekableByteChannel that concatenates all provided channels; if exactly one channel is given, it is returned as-is
     */
    public static SeekableByteChannel forSeekableByteChannels(final SeekableByteChannel... channels) {
        if (Objects.requireNonNull(channels, "channels").length == 1) {
            return channels[0];
        }
        return new MultiReadOnlySeekableByteChannel(Arrays.asList(channels));
    }

    /** Immutable snapshot of the wrapped channels, in concatenation order. */
    private final List<SeekableByteChannel> channelList;

    /** Position within the concatenated view. */
    private long globalPosition;

    /** Index into {@link #channelList} of the channel the next read starts from. */
    private int currentChannelIdx;

    /**
     * Set by {@link #close()} so that {@link #isOpen()} reports {@code false} afterwards even when there are no wrapped channels.
     */
    private volatile boolean closed;

    /**
     * Concatenates the given channels.
     *
     * @param channels the channels to concatenate; the list is copied, later changes to it are not seen by this channel
     * @throws NullPointerException if channels is null
     */
    public MultiReadOnlySeekableByteChannel(final List<SeekableByteChannel> channels) {
        this.channelList = Collections.unmodifiableList(new ArrayList<>(Objects.requireNonNull(channels, "channels")));
    }

    /**
     * Closes every wrapped channel.
     * <p>
     * All channels are closed even if some of them fail. The first failure is reported as the cause of the thrown exception.
     * </p>
     *
     * @throws IOException if closing at least one wrapped channel fails
     */
    @Override
    public void close() throws IOException {
        closed = true;
        IOException first = null;
        for (final SeekableByteChannel ch : channelList) {
            try {
                ch.close();
            } catch (final IOException ex) {
                if (first == null) {
                    first = ex;
                }
            }
        }
        if (first != null) {
            throw new IOException("failed to close wrapped channel", first);
        }
    }

    /**
     * Tells whether this channel is open.
     *
     * @return {@code false} once {@link #close()} has been called or any wrapped channel is closed, {@code true} otherwise
     */
    @Override
    public boolean isOpen() {
        return !closed && channelList.stream().allMatch(SeekableByteChannel::isOpen);
    }

    /**
     * Gets this channel's position.
     * <p>
     * This method violates the contract of {@link SeekableByteChannel#position()} as it will not throw any exception when invoked on a closed channel. Instead
     * it will return the position the channel had when close has been called.
     * </p>
     *
     * @return the position within the concatenated view
     */
    @Override
    public synchronized long position() {
        // synchronized like the mutators: 64-bit reads of a non-volatile long are not guaranteed to be atomic.
        return globalPosition;
    }

    /**
     * Sets the position within the concatenated view.
     * <p>
     * The wrapped channel containing {@code newPosition} is positioned at the matching offset. Channels before it are positioned at their end and
     * channels after it at 0. A position past the total size leaves every channel at its end, so subsequent reads report end-of-stream.
     * </p>
     *
     * @param newPosition the new position, must not be negative
     * @return this channel
     * @throws IllegalArgumentException if newPosition is negative
     * @throws ClosedChannelException   if this channel is closed
     * @throws IOException              if querying or positioning a wrapped channel fails
     */
    @Override
    public synchronized SeekableByteChannel position(final long newPosition) throws IOException {
        if (newPosition < 0) {
            throw new IllegalArgumentException("Negative position: " + newPosition);
        }
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        globalPosition = newPosition;
        long pos = newPosition;
        for (int i = 0; i < channelList.size(); i++) {
            final SeekableByteChannel currentChannel = channelList.get(i);
            final long size = currentChannel.size();

            final long newChannelPos;
            if (pos == -1L) {
                // Position is already set for the correct channel,
                // the rest of the channels get reset to 0
                newChannelPos = 0;
            } else if (pos <= size) {
                // This channel is where we want to be
                currentChannelIdx = i;
                final long tmp = pos;
                pos = -1L; // Mark pos as already being set
                newChannelPos = tmp;
            } else {
                // newPosition is past this channel. Set channel
                // position to the end and subtract channel size from
                // pos
                pos -= size;
                newChannelPos = size;
            }
            currentChannel.position(newChannelPos);
        }
        return this;
    }

    /**
     * Sets the position based on the given channel number and relative offset.
     * <p>
     * The global position is the sum of the sizes of all channels before {@code channelNumber} plus {@code relativeOffset}.
     * </p>
     *
     * @param channelNumber  the zero-based channel number
     * @param relativeOffset the relative offset in the corresponding channel
     * @return this channel
     * @throws ClosedChannelException    if this channel is closed
     * @throws IndexOutOfBoundsException if channelNumber is greater than the number of wrapped channels
     * @throws IllegalArgumentException  if the resulting global position is negative
     * @throws IOException               if positioning fails
     */
    public synchronized SeekableByteChannel position(final long channelNumber, final long relativeOffset) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        long target = relativeOffset;
        for (int i = 0; i < channelNumber; i++) {
            target += channelList.get(i).size();
        }
        return position(target);
    }

    /**
     * Reads bytes into {@code dst}, continuing into subsequent wrapped channels until {@code dst} is full or all channels are exhausted.
     *
     * @param dst the buffer to fill
     * @return the number of bytes read; 0 if {@code dst} has no remaining space; -1 if no bytes could be read because the end was reached
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException            if reading from a wrapped channel fails
     */
    @Override
    public synchronized int read(final ByteBuffer dst) throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        if (!dst.hasRemaining()) {
            return 0;
        }

        int totalBytesRead = 0;
        while (dst.hasRemaining() && currentChannelIdx < channelList.size()) {
            final SeekableByteChannel currentChannel = channelList.get(currentChannelIdx);
            final int newBytesRead = currentChannel.read(dst);
            if (newBytesRead == -1) {
                // EOF for this channel -- advance to next channel idx
                currentChannelIdx += 1;
                continue;
            }
            if (currentChannel.position() >= currentChannel.size()) {
                // we are at the end of the current channel
                currentChannelIdx++;
            }
            totalBytesRead += newBytesRead;
        }
        if (totalBytesRead > 0) {
            globalPosition += totalBytesRead;
            return totalBytesRead;
        }
        return -1;
    }

    /**
     * Gets the total size of the concatenated view.
     *
     * @return the sum of the sizes of all wrapped channels
     * @throws ClosedChannelException if this channel is closed
     * @throws IOException            if querying the size of a wrapped channel fails
     */
    @Override
    public long size() throws IOException {
        if (!isOpen()) {
            throw new ClosedChannelException();
        }
        long acc = 0;
        for (final SeekableByteChannel ch : channelList) {
            acc += ch.size();
        }
        return acc;
    }

    /**
     * Always fails; this channel is read-only.
     *
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public SeekableByteChannel truncate(final long size) {
        throw new NonWritableChannelException();
    }

    /**
     * Always fails; this channel is read-only.
     *
     * @throws NonWritableChannelException since this implementation is read-only.
     */
    @Override
    public int write(final ByteBuffer src) {
        throw new NonWritableChannelException();
    }

}