001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   https://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.commons.compress.compressors.lz4;
020
021import java.io.IOException;
022import java.io.InputStream;
023import java.util.Arrays;
024import java.util.zip.CheckedInputStream;
025
026import org.apache.commons.compress.compressors.CompressorInputStream;
027import org.apache.commons.compress.utils.ByteUtils;
028import org.apache.commons.compress.utils.IOUtils;
029import org.apache.commons.compress.utils.InputStreamStatistics;
030import org.apache.commons.io.input.BoundedInputStream;
031
032/**
033 * CompressorInputStream for the LZ4 frame format.
034 *
035 * <p>
036 * Based on the "spec" in the version "1.5.1 (31/03/2015)"
037 * </p>
038 *
039 * @see <a href="https://lz4.github.io/lz4/lz4_Frame_format.html">LZ4 Frame Format Description</a>
040 * @since 1.14
041 * @NotThreadSafe
042 */
043public class FramedLZ4CompressorInputStream extends CompressorInputStream implements InputStreamStatistics {
044
045    /** Used by FramedLZ4CompressorOutputStream as well. */
046    static final byte[] LZ4_SIGNATURE = { 4, 0x22, 0x4d, 0x18 };
047    private static final byte[] SKIPPABLE_FRAME_TRAILER = { 0x2a, 0x4d, 0x18 };
048    private static final byte SKIPPABLE_FRAME_PREFIX_BYTE_MASK = 0x50;
049
050    static final int VERSION_MASK = 0xC0;
051    static final int SUPPORTED_VERSION = 0x40;
052    static final int BLOCK_INDEPENDENCE_MASK = 0x20;
053    static final int BLOCK_CHECKSUM_MASK = 0x10;
054    static final int CONTENT_SIZE_MASK = 0x08;
055    static final int CONTENT_CHECKSUM_MASK = 0x04;
056    static final int BLOCK_MAX_SIZE_MASK = 0x70;
057    static final int UNCOMPRESSED_FLAG_MASK = 0x80000000;
058
059    private static boolean isSkippableFrameSignature(final byte[] b) {
060        if ((b[0] & SKIPPABLE_FRAME_PREFIX_BYTE_MASK) != SKIPPABLE_FRAME_PREFIX_BYTE_MASK) {
061            return false;
062        }
063        for (int i = 1; i < 4; i++) {
064            if (b[i] != SKIPPABLE_FRAME_TRAILER[i - 1]) {
065                return false;
066            }
067        }
068        return true;
069    }
070
071    /**
072     * Checks if the signature matches what is expected for a .lz4 file.
073     * <p>
074     * .lz4 files start with a four byte signature.
075     * </p>
076     *
077     * @param signature the bytes to check
078     * @param length    the number of bytes to check
079     * @return true if this is a .sz stream, false otherwise
080     */
081    public static boolean matches(final byte[] signature, final int length) {
082
083        if (length < LZ4_SIGNATURE.length) {
084            return false;
085        }
086
087        byte[] shortenedSig = signature;
088        if (signature.length > LZ4_SIGNATURE.length) {
089            shortenedSig = Arrays.copyOf(signature, LZ4_SIGNATURE.length);
090        }
091
092        return Arrays.equals(shortenedSig, LZ4_SIGNATURE);
093    }
094
095    /** Used in no-arg read method. */
096    private final byte[] oneByte = new byte[1];
097    private final ByteUtils.ByteSupplier supplier = this::readOneByte;
098
099    private final BoundedInputStream inputStream;
100    private final boolean decompressConcatenated;
101    private boolean expectBlockChecksum;
102    private boolean expectBlockDependency;
103
104    private boolean expectContentChecksum;
105
106    private InputStream currentBlock;
107
108    private boolean endReached;
109    private boolean inUncompressed;
110
111    /** Used for frame header checksum and content checksum, if present. */
112    private final org.apache.commons.codec.digest.XXHash32 contentHash = new org.apache.commons.codec.digest.XXHash32();
113
114    /** Used for block checksum, if present. */
115    private final org.apache.commons.codec.digest.XXHash32 blockHash = new org.apache.commons.codec.digest.XXHash32();
116
117    /** Only created if the frame doesn't set the block independence flag. */
118    private byte[] blockDependencyBuffer;
119
120    /**
121     * Creates a new input stream that decompresses streams compressed using the LZ4 frame format and stops after decompressing the first frame.
122     *
123     * @param in the InputStream from which to read the compressed data
124     * @throws IOException if reading fails
125     */
126    public FramedLZ4CompressorInputStream(final InputStream in) throws IOException {
127        this(in, false);
128    }
129
130    /**
131     * Creates a new input stream that decompresses streams compressed using the LZ4 frame format.
132     *
133     * @param in                     the InputStream from which to read the compressed data
134     * @param decompressConcatenated if true, decompress until the end of the input; if false, stop after the first LZ4 frame and leave the input position to
135     *                               point to the next byte after the frame stream
136     * @throws IOException if reading fails
137     */
138    public FramedLZ4CompressorInputStream(final InputStream in, final boolean decompressConcatenated) throws IOException {
139        this.inputStream = BoundedInputStream.builder().setInputStream(in).get();
140        this.decompressConcatenated = decompressConcatenated;
141        init(true);
142    }
143
144    private void appendToBlockDependencyBuffer(final byte[] b, final int off, int len) {
145        len = Math.min(len, blockDependencyBuffer.length);
146        if (len > 0) {
147            final int keep = blockDependencyBuffer.length - len;
148            if (keep > 0) {
149                // move last keep bytes towards the start of the buffer
150                System.arraycopy(blockDependencyBuffer, len, blockDependencyBuffer, 0, keep);
151            }
152            // append new data
153            System.arraycopy(b, off, blockDependencyBuffer, keep, len);
154        }
155    }
156
157    /** {@inheritDoc} */
158    @Override
159    public void close() throws IOException {
160        try {
161            org.apache.commons.io.IOUtils.close(currentBlock);
162            currentBlock = null;
163        } finally {
164            inputStream.close();
165        }
166    }
167
168    /**
169     * @since 1.17
170     */
171    @Override
172    public long getCompressedCount() {
173        return inputStream.getCount();
174    }
175
176    private void init(final boolean firstFrame) throws IOException {
177        if (readSignature(firstFrame)) {
178            readFrameDescriptor();
179            nextBlock();
180        }
181    }
182
183    private void maybeFinishCurrentBlock() throws IOException {
184        if (currentBlock != null) {
185            currentBlock.close();
186            currentBlock = null;
187            if (expectBlockChecksum) {
188                verifyChecksum(blockHash, "block");
189                blockHash.reset();
190            }
191        }
192    }
193
194    private void nextBlock() throws IOException {
195        maybeFinishCurrentBlock();
196        final long len = ByteUtils.fromLittleEndian(supplier, 4);
197        final boolean uncompressed = (len & UNCOMPRESSED_FLAG_MASK) != 0;
198        final int realLen = (int) (len & ~UNCOMPRESSED_FLAG_MASK);
199        if (realLen == 0) {
200            verifyContentChecksum();
201            if (!decompressConcatenated) {
202                endReached = true;
203            } else {
204                init(false);
205            }
206            return;
207        }
208        // @formatter:off
209        InputStream capped = BoundedInputStream.builder()
210                .setInputStream(inputStream)
211                .setMaxCount(realLen)
212                .setPropagateClose(false)
213                .get();
214        // @formatter:on
215        if (expectBlockChecksum) {
216            capped = new CheckedInputStream(capped, blockHash);
217        }
218        if (uncompressed) {
219            inUncompressed = true;
220            currentBlock = capped;
221        } else {
222            inUncompressed = false;
223            final BlockLZ4CompressorInputStream s = new BlockLZ4CompressorInputStream(capped);
224            if (expectBlockDependency) {
225                s.prefill(blockDependencyBuffer);
226            }
227            currentBlock = s;
228        }
229    }
230
231    /** {@inheritDoc} */
232    @Override
233    public int read() throws IOException {
234        return read(oneByte, 0, 1) == -1 ? -1 : oneByte[0] & 0xFF;
235    }
236
237    /** {@inheritDoc} */
238    @Override
239    public int read(final byte[] b, final int off, final int len) throws IOException {
240        if (len == 0) {
241            return 0;
242        }
243        if (endReached) {
244            return -1;
245        }
246        int r = readOnce(b, off, len);
247        if (r == -1) {
248            nextBlock();
249            if (!endReached) {
250                r = readOnce(b, off, len);
251            }
252        }
253        if (r != -1) {
254            if (expectBlockDependency) {
255                appendToBlockDependencyBuffer(b, off, r);
256            }
257            if (expectContentChecksum) {
258                contentHash.update(b, off, r);
259            }
260        }
261        return r;
262    }
263
264    private void readFrameDescriptor() throws IOException {
265        final int flags = readOneByte();
266        if (flags == -1) {
267            throw new IOException("Premature end of stream while reading frame flags");
268        }
269        contentHash.update(flags);
270        if ((flags & VERSION_MASK) != SUPPORTED_VERSION) {
271            throw new IOException("Unsupported version " + (flags >> 6));
272        }
273        expectBlockDependency = (flags & BLOCK_INDEPENDENCE_MASK) == 0;
274        if (expectBlockDependency) {
275            if (blockDependencyBuffer == null) {
276                blockDependencyBuffer = new byte[BlockLZ4CompressorInputStream.WINDOW_SIZE];
277            }
278        } else {
279            blockDependencyBuffer = null;
280        }
281        expectBlockChecksum = (flags & BLOCK_CHECKSUM_MASK) != 0;
282        final boolean expectContentSize = (flags & CONTENT_SIZE_MASK) != 0;
283        expectContentChecksum = (flags & CONTENT_CHECKSUM_MASK) != 0;
284        final int bdByte = readOneByte();
285        if (bdByte == -1) { // max size is irrelevant for this implementation
286            throw new IOException("Premature end of stream while reading frame BD byte");
287        }
288        contentHash.update(bdByte);
289        if (expectContentSize) { // for now, we don't care, contains the uncompressed size
290            final byte[] contentSize = new byte[8];
291            final int skipped = IOUtils.readFully(inputStream, contentSize);
292            count(skipped);
293            if (8 != skipped) {
294                throw new IOException("Premature end of stream while reading content size");
295            }
296            contentHash.update(contentSize, 0, contentSize.length);
297        }
298        final int headerHash = readOneByte();
299        if (headerHash == -1) { // partial hash of header.
300            throw new IOException("Premature end of stream while reading frame header checksum");
301        }
302        final int expectedHash = (int) (contentHash.getValue() >> 8 & 0xff);
303        contentHash.reset();
304        if (headerHash != expectedHash) {
305            throw new IOException("Frame header checksum mismatch");
306        }
307    }
308
309    private int readOnce(final byte[] b, final int off, final int len) throws IOException {
310        if (inUncompressed) {
311            final int cnt = currentBlock.read(b, off, len);
312            count(cnt);
313            return cnt;
314        }
315        final BlockLZ4CompressorInputStream l = (BlockLZ4CompressorInputStream) currentBlock;
316        final long before = l.getBytesRead();
317        final int cnt = currentBlock.read(b, off, len);
318        count(l.getBytesRead() - before);
319        return cnt;
320    }
321
322    private int readOneByte() throws IOException {
323        final int b = inputStream.read();
324        if (b != -1) {
325            count(1);
326            return b & 0xFF;
327        }
328        return -1;
329    }
330
331    private boolean readSignature(final boolean firstFrame) throws IOException {
332        final String garbageMessage = firstFrame ? "Not a LZ4 frame stream" : "LZ4 frame stream followed by garbage";
333        final byte[] b = new byte[4];
334        int read = IOUtils.readFully(inputStream, b);
335        count(read);
336        if (0 == read && !firstFrame) {
337            // good LZ4 frame and nothing after it
338            endReached = true;
339            return false;
340        }
341        if (4 != read) {
342            throw new IOException(garbageMessage);
343        }
344
345        read = skipSkippableFrame(b);
346        if (0 == read && !firstFrame) {
347            // good LZ4 frame with only some skippable frames after it
348            endReached = true;
349            return false;
350        }
351        if (4 != read || !matches(b, 4)) {
352            throw new IOException(garbageMessage);
353        }
354        return true;
355    }
356
357    /**
358     * Skips over the contents of a skippable frame as well as skippable frames following it.
359     * <p>
360     * It then tries to read four more bytes which are supposed to hold an LZ4 signature and returns the number of bytes read while storing the bytes in the
361     * given array.
362     * </p>
363     */
364    private int skipSkippableFrame(final byte[] b) throws IOException {
365        int read = 4;
366        while (read == 4 && isSkippableFrameSignature(b)) {
367            final long len = ByteUtils.fromLittleEndian(supplier, 4);
368            if (len < 0) {
369                throw new IOException("Found illegal skippable frame with negative size");
370            }
371            final long skipped = org.apache.commons.io.IOUtils.skip(inputStream, len);
372            count(skipped);
373            if (len != skipped) {
374                throw new IOException("Premature end of stream while skipping frame");
375            }
376            read = IOUtils.readFully(inputStream, b);
377            count(read);
378        }
379        return read;
380    }
381
382    private void verifyChecksum(final org.apache.commons.codec.digest.XXHash32 hash, final String kind) throws IOException {
383        final byte[] checksum = new byte[4];
384        final int read = IOUtils.readFully(inputStream, checksum);
385        count(read);
386        if (4 != read) {
387            throw new IOException("Premature end of stream while reading " + kind + " checksum");
388        }
389        final long expectedHash = hash.getValue();
390        if (expectedHash != ByteUtils.fromLittleEndian(checksum)) {
391            throw new IOException(kind + " checksum mismatch.");
392        }
393    }
394
395    private void verifyContentChecksum() throws IOException {
396        if (expectContentChecksum) {
397            verifyChecksum(contentHash, "content");
398        }
399        contentHash.reset();
400    }
401}