HuffmanDecoder.java

  1. /*
  2.  *  Licensed to the Apache Software Foundation (ASF) under one or more
  3.  *  contributor license agreements.  See the NOTICE file distributed with
  4.  *  this work for additional information regarding copyright ownership.
  5.  *  The ASF licenses this file to You under the Apache License, Version 2.0
  6.  *  (the "License"); you may not use this file except in compliance with
  7.  *  the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  *  Unless required by applicable law or agreed to in writing, software
  12.  *  distributed under the License is distributed on an "AS IS" BASIS,
  13.  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  *  See the License for the specific language governing permissions and
  15.  *  limitations under the License.
  16.  */
  17. package org.apache.commons.compress.compressors.deflate64;

  18. import static org.apache.commons.compress.compressors.deflate64.HuffmanState.DYNAMIC_CODES;
  19. import static org.apache.commons.compress.compressors.deflate64.HuffmanState.FIXED_CODES;
  20. import static org.apache.commons.compress.compressors.deflate64.HuffmanState.INITIAL;
  21. import static org.apache.commons.compress.compressors.deflate64.HuffmanState.STORED;

  22. import java.io.Closeable;
  23. import java.io.EOFException;
  24. import java.io.IOException;
  25. import java.io.InputStream;
  26. import java.nio.ByteOrder;
  27. import java.util.Arrays;

  28. import org.apache.commons.compress.utils.BitInputStream;
  29. import org.apache.commons.compress.utils.ByteUtils;
  30. import org.apache.commons.compress.utils.ExactMath;
  31. import org.apache.commons.lang3.ArrayFill;

  32. /**
  33.  * TODO This class can't be final because it is mocked by Mockito.
  34.  */
  35. class HuffmanDecoder implements Closeable {

  36.     private static final class BinaryTreeNode {
  37.         private final int bits;
  38.         int literal = -1;
  39.         BinaryTreeNode leftNode;
  40.         BinaryTreeNode rightNode;

  41.         private BinaryTreeNode(final int bits) {
  42.             this.bits = bits;
  43.         }

  44.         void leaf(final int symbol) {
  45.             literal = symbol;
  46.             leftNode = null;
  47.             rightNode = null;
  48.         }

  49.         BinaryTreeNode left() {
  50.             if (leftNode == null && literal == -1) {
  51.                 leftNode = new BinaryTreeNode(bits + 1);
  52.             }
  53.             return leftNode;
  54.         }

  55.         BinaryTreeNode right() {
  56.             if (rightNode == null && literal == -1) {
  57.                 rightNode = new BinaryTreeNode(bits + 1);
  58.             }
  59.             return rightNode;
  60.         }
  61.     }

  62.     private abstract static class DecoderState {
  63.         abstract int available() throws IOException;

  64.         abstract boolean hasData();

  65.         abstract int read(byte[] b, int off, int len) throws IOException;

  66.         abstract HuffmanState state();
  67.     }

  68.     private static final class DecodingMemory {
  69.         private final byte[] memory;
  70.         private final int mask;
  71.         private int wHead;
  72.         private boolean wrappedAround;

  73.         private DecodingMemory() {
  74.             this(16);
  75.         }

  76.         private DecodingMemory(final int bits) {
  77.             memory = new byte[1 << bits];
  78.             mask = memory.length - 1;
  79.         }

  80.         byte add(final byte b) {
  81.             memory[wHead] = b;
  82.             wHead = incCounter(wHead);
  83.             return b;
  84.         }

  85.         void add(final byte[] b, final int off, final int len) {
  86.             for (int i = off; i < off + len; i++) {
  87.                 add(b[i]);
  88.             }
  89.         }

  90.         private int incCounter(final int counter) {
  91.             final int newCounter = counter + 1 & mask;
  92.             if (!wrappedAround && newCounter < counter) {
  93.                 wrappedAround = true;
  94.             }
  95.             return newCounter;
  96.         }

  97.         void recordToBuffer(final int distance, final int length, final byte[] buff) {
  98.             if (distance > memory.length) {
  99.                 throw new IllegalStateException("Illegal distance parameter: " + distance);
  100.             }
  101.             final int start = wHead - distance & mask;
  102.             if (!wrappedAround && start >= wHead) {
  103.                 throw new IllegalStateException("Attempt to read beyond memory: dist=" + distance);
  104.             }
  105.             for (int i = 0, pos = start; i < length; i++, pos = incCounter(pos)) {
  106.                 buff[i] = add(memory[pos]);
  107.             }
  108.         }
  109.     }

  110.     private final class HuffmanCodes extends DecoderState {
  111.         private boolean endOfBlock;
  112.         private final HuffmanState state;
  113.         private final BinaryTreeNode lengthTree;
  114.         private final BinaryTreeNode distanceTree;

  115.         private int runBufferPos;
  116.         private byte[] runBuffer = ByteUtils.EMPTY_BYTE_ARRAY;
  117.         private int runBufferLength;

  118.         HuffmanCodes(final HuffmanState state, final int[] lengths, final int[] distance) {
  119.             this.state = state;
  120.             lengthTree = buildTree(lengths);
  121.             distanceTree = buildTree(distance);
  122.         }

  123.         @Override
  124.         int available() {
  125.             return runBufferLength - runBufferPos;
  126.         }

  127.         private int copyFromRunBuffer(final byte[] b, final int off, final int len) {
  128.             final int bytesInBuffer = runBufferLength - runBufferPos;
  129.             int copiedBytes = 0;
  130.             if (bytesInBuffer > 0) {
  131.                 copiedBytes = Math.min(len, bytesInBuffer);
  132.                 System.arraycopy(runBuffer, runBufferPos, b, off, copiedBytes);
  133.                 runBufferPos += copiedBytes;
  134.             }
  135.             return copiedBytes;
  136.         }

  137.         private int decodeNext(final byte[] b, final int off, final int len) throws IOException {
  138.             if (endOfBlock) {
  139.                 return -1;
  140.             }
  141.             int result = copyFromRunBuffer(b, off, len);

  142.             while (result < len) {
  143.                 final int symbol = nextSymbol(reader, lengthTree);
  144.                 if (symbol < 256) {
  145.                     b[off + result++] = memory.add((byte) symbol);
  146.                 } else if (symbol > 256) {
  147.                     final int runMask = RUN_LENGTH_TABLE[symbol - 257];
  148.                     int run = runMask >>> 5;
  149.                     final int runXtra = runMask & 0x1F;
  150.                     run = ExactMath.add(run, readBits(runXtra));

  151.                     final int distSym = nextSymbol(reader, distanceTree);

  152.                     final int distMask = DISTANCE_TABLE[distSym];
  153.                     int dist = distMask >>> 4;
  154.                     final int distXtra = distMask & 0xF;
  155.                     dist = ExactMath.add(dist, readBits(distXtra));

  156.                     if (runBuffer.length < run) {
  157.                         runBuffer = new byte[run];
  158.                     }
  159.                     runBufferLength = run;
  160.                     runBufferPos = 0;
  161.                     memory.recordToBuffer(dist, run, runBuffer);

  162.                     result += copyFromRunBuffer(b, off + result, len - result);
  163.                 } else {
  164.                     endOfBlock = true;
  165.                     return result;
  166.                 }
  167.             }

  168.             return result;
  169.         }

  170.         @Override
  171.         boolean hasData() {
  172.             return !endOfBlock;
  173.         }

  174.         @Override
  175.         int read(final byte[] b, final int off, final int len) throws IOException {
  176.             if (len == 0) {
  177.                 return 0;
  178.             }
  179.             return decodeNext(b, off, len);
  180.         }

  181.         @Override
  182.         HuffmanState state() {
  183.             return endOfBlock ? INITIAL : state;
  184.         }
  185.     }

  186.     private static final class InitialState extends DecoderState {
  187.         @Override
  188.         int available() {
  189.             return 0;
  190.         }

  191.         @Override
  192.         boolean hasData() {
  193.             return false;
  194.         }

  195.         @Override
  196.         int read(final byte[] b, final int off, final int len) throws IOException {
  197.             if (len == 0) {
  198.                 return 0;
  199.             }
  200.             throw new IllegalStateException("Cannot read in this state");
  201.         }

  202.         @Override
  203.         HuffmanState state() {
  204.             return INITIAL;
  205.         }
  206.     }

  207.     private final class UncompressedState extends DecoderState {
  208.         private final long blockLength;
  209.         private long read;

  210.         private UncompressedState(final long blockLength) {
  211.             this.blockLength = blockLength;
  212.         }

  213.         @Override
  214.         int available() throws IOException {
  215.             return (int) Math.min(blockLength - read, reader.bitsAvailable() / Byte.SIZE);
  216.         }

  217.         @Override
  218.         boolean hasData() {
  219.             return read < blockLength;
  220.         }

  221.         @Override
  222.         int read(final byte[] b, final int off, final int len) throws IOException {
  223.             if (len == 0) {
  224.                 return 0;
  225.             }
  226.             // as len is an int and (blockLength - read) is >= 0 the min must fit into an int as well
  227.             final int max = (int) Math.min(blockLength - read, len);
  228.             int readSoFar = 0;
  229.             while (readSoFar < max) {
  230.                 final int readNow;
  231.                 if (reader.bitsCached() > 0) {
  232.                     final byte next = (byte) readBits(Byte.SIZE);
  233.                     b[off + readSoFar] = memory.add(next);
  234.                     readNow = 1;
  235.                 } else {
  236.                     readNow = in.read(b, off + readSoFar, max - readSoFar);
  237.                     if (readNow == -1) {
  238.                         throw new EOFException("Truncated Deflate64 Stream");
  239.                     }
  240.                     memory.add(b, off + readSoFar, readNow);
  241.                 }
  242.                 read += readNow;
  243.                 readSoFar += readNow;
  244.             }
  245.             return max;
  246.         }

  247.         @Override
  248.         HuffmanState state() {
  249.             return read < blockLength ? STORED : INITIAL;
  250.         }
  251.     }

  252.     /**
  253.      * <pre>
  254.      * --------------------------------------------------------------------
  255.      * idx  xtra  base     idx  xtra  base     idx  xtra  base
  256.      * --------------------------------------------------------------------
  257.      * 257   0     3       267   1   15,16     277   4   67-82
  258.      * 258   0     4       268   1   17,18     278   4   83-98
  259.      * 259   0     5       269   2   19-22     279   4   99-114
  260.      * 260   0     6       270   2   23-26     280   4   115-130
  261.      * 261   0     7       271   2   27-30     281   5   131-162
  262.      * 262   0     8       272   2   31-34     282   5   163-194
  263.      * 263   0     9       273   3   35-42     283   5   195-226
  264.      * 264   0     10      274   3   43-50     284   5   227-257
  265.      * 265   1     11,12   275   3   51-58     285   16  3
  266.      * 266   1     13,14   276   3   59-66
  267.      * --------------------------------------------------------------------
  268.      * </pre>
  269.      *
  270.      * value = (base of run length) << 5 | (number of extra bits to read)
  271.      */
  272.     private static final short[] RUN_LENGTH_TABLE = { 96, 128, 160, 192, 224, 256, 288, 320, 353, 417, 481, 545, 610, 738, 866, 994, 1123, 1379, 1635, 1891,
  273.             2148, 2660, 3172, 3684, 4197, 5221, 6245, 7269, 112 };
  274.     /**
  275.      * <pre>
  276.      * --------------------------------------------------------------------
  277.      * idx  xtra  dist     idx  xtra  dist       idx  xtra  dist
  278.      * --------------------------------------------------------------------
  279.      * 0    0     1        10   4     33-48      20    9   1025-1536
  280.      * 1    0     2        11   4     49-64      21    9   1537-2048
  281.      * 2    0     3        12   5     65-96      22   10   2049-3072
  282.      * 3    0     4        13   5     97-128     23   10   3073-4096
  283.      * 4    1     5,6      14   6     129-192    24   11   4097-6144
  284.      * 5    1     7,8      15   6     193-256    25   11   6145-8192
  285.      * 6    2     9-12     16   7     257-384    26   12   8193-12288
  286.      * 7    2     13-16    17   7     385-512    27   12   12289-16384
  287.      * 8    3     17-24    18   8     513-768    28   13   16385-24576
  288.      * 9    3     25-32    19   8     769-1024   29   13   24577-32768
  289.      * 30   14   32769-49152
  290.      * 31   14   49153-65536
  291.      * --------------------------------------------------------------------
  292.      * </pre>
  293.      *
  294.      * value = (base of distance) << 4 | (number of extra bits to read)
  295.      */
  296.     private static final int[] DISTANCE_TABLE = { 16, 32, 48, 64, 81, 113, 146, 210, 275, 403, // 0-9
  297.             532, 788, 1045, 1557, 2070, 3094, 4119, 6167, 8216, 12312, // 10-19
  298.             16409, 24601, 32794, 49178, 65563, 98331, 131100, 196636, 262173, 393245, // 20-29
  299.             524318, 786462 // 30-31
  300.     };
  301.     /**
  302.      * When using dynamic huffman codes the order in which the values are stored follows the positioning below
  303.      */
  304.     private static final int[] CODE_LENGTHS_ORDER = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 };
  305.     /**
  306.      * Huffman Fixed Literal / Distance tables for mode 1
  307.      */
  308.     private static final int[] FIXED_LITERALS;

  309.     private static final int[] FIXED_DISTANCE;

  310.     static {
  311.         FIXED_LITERALS = new int[288];
  312.         Arrays.fill(FIXED_LITERALS, 0, 144, 8);
  313.         Arrays.fill(FIXED_LITERALS, 144, 256, 9);
  314.         Arrays.fill(FIXED_LITERALS, 256, 280, 7);
  315.         Arrays.fill(FIXED_LITERALS, 280, 288, 8);

  316.         FIXED_DISTANCE = ArrayFill.fill(new int[32], 5);
  317.     }

  318.     private static BinaryTreeNode buildTree(final int[] litTable) {
  319.         final int[] literalCodes = getCodes(litTable);

  320.         final BinaryTreeNode root = new BinaryTreeNode(0);

  321.         for (int i = 0; i < litTable.length; i++) {
  322.             final int len = litTable[i];
  323.             if (len != 0) {
  324.                 BinaryTreeNode node = root;
  325.                 final int lit = literalCodes[len - 1];
  326.                 for (int p = len - 1; p >= 0; p--) {
  327.                     final int bit = lit & 1 << p;
  328.                     node = bit == 0 ? node.left() : node.right();
  329.                     if (node == null) {
  330.                         throw new IllegalStateException("node doesn't exist in Huffman tree");
  331.                     }
  332.                 }
  333.                 node.leaf(i);
  334.                 literalCodes[len - 1]++;
  335.             }
  336.         }
  337.         return root;
  338.     }

  339.     private static int[] getCodes(final int[] litTable) {
  340.         int max = 0;
  341.         int[] blCount = new int[65];

  342.         for (final int aLitTable : litTable) {
  343.             if (aLitTable < 0 || aLitTable > 64) {
  344.                 throw new IllegalArgumentException("Invalid code " + aLitTable + " in literal table");
  345.             }
  346.             max = Math.max(max, aLitTable);
  347.             blCount[aLitTable]++;
  348.         }
  349.         blCount = Arrays.copyOf(blCount, max + 1);

  350.         int code = 0;
  351.         final int[] nextCode = new int[max + 1];
  352.         for (int i = 0; i <= max; i++) {
  353.             code = code + blCount[i] << 1;
  354.             nextCode[i] = code;
  355.         }

  356.         return nextCode;
  357.     }

  358.     private static int nextSymbol(final BitInputStream reader, final BinaryTreeNode tree) throws IOException {
  359.         BinaryTreeNode node = tree;
  360.         while (node != null && node.literal == -1) {
  361.             final long bit = readBits(reader, 1);
  362.             node = bit == 0 ? node.leftNode : node.rightNode;
  363.         }
  364.         return node != null ? node.literal : -1;
  365.     }

  366.     private static void populateDynamicTables(final BitInputStream reader, final int[] literals, final int[] distances) throws IOException {
  367.         final int codeLengths = (int) (readBits(reader, 4) + 4);

  368.         final int[] codeLengthValues = new int[19];
  369.         for (int cLen = 0; cLen < codeLengths; cLen++) {
  370.             codeLengthValues[CODE_LENGTHS_ORDER[cLen]] = (int) readBits(reader, 3);
  371.         }

  372.         final BinaryTreeNode codeLengthTree = buildTree(codeLengthValues);

  373.         final int[] auxBuffer = new int[literals.length + distances.length];

  374.         int value = -1;
  375.         int length = 0;
  376.         int off = 0;
  377.         while (off < auxBuffer.length) {
  378.             if (length > 0) {
  379.                 auxBuffer[off++] = value;
  380.                 length--;
  381.             } else {
  382.                 final int symbol = nextSymbol(reader, codeLengthTree);
  383.                 if (symbol < 16) {
  384.                     value = symbol;
  385.                     auxBuffer[off++] = value;
  386.                 } else {
  387.                     switch (symbol) {
  388.                     case 16:
  389.                         length = (int) (readBits(reader, 2) + 3);
  390.                         break;
  391.                     case 17:
  392.                         value = 0;
  393.                         length = (int) (readBits(reader, 3) + 3);
  394.                         break;
  395.                     case 18:
  396.                         value = 0;
  397.                         length = (int) (readBits(reader, 7) + 11);
  398.                         break;
  399.                     default:
  400.                         break;
  401.                     }
  402.                 }
  403.             }
  404.         }

  405.         System.arraycopy(auxBuffer, 0, literals, 0, literals.length);
  406.         System.arraycopy(auxBuffer, literals.length, distances, 0, distances.length);
  407.     }

  408.     private static long readBits(final BitInputStream reader, final int numBits) throws IOException {
  409.         final long r = reader.readBits(numBits);
  410.         if (r == -1) {
  411.             throw new EOFException("Truncated Deflate64 Stream");
  412.         }
  413.         return r;
  414.     }

  415.     private boolean finalBlock;

  416.     private DecoderState state;

  417.     private BitInputStream reader;

  418.     private final InputStream in;

  419.     private final DecodingMemory memory = new DecodingMemory();

  420.     HuffmanDecoder(final InputStream in) {
  421.         this.reader = new BitInputStream(in, ByteOrder.LITTLE_ENDIAN);
  422.         this.in = in;
  423.         state = new InitialState();
  424.     }

  425.     int available() throws IOException {
  426.         return state.available();
  427.     }

  428.     @Override
  429.     public void close() {
  430.         state = new InitialState();
  431.         reader = null;
  432.     }

  433.     public int decode(final byte[] b) throws IOException {
  434.         return decode(b, 0, b.length);
  435.     }

  436.     public int decode(final byte[] b, final int off, final int len) throws IOException {
  437.         while (!finalBlock || state.hasData()) {
  438.             if (state.state() == INITIAL) {
  439.                 finalBlock = readBits(1) == 1;
  440.                 final int mode = (int) readBits(2);
  441.                 switch (mode) {
  442.                 case 0:
  443.                     switchToUncompressedState();
  444.                     break;
  445.                 case 1:
  446.                     state = new HuffmanCodes(FIXED_CODES, FIXED_LITERALS, FIXED_DISTANCE);
  447.                     break;
  448.                 case 2:
  449.                     final int[][] tables = readDynamicTables();
  450.                     state = new HuffmanCodes(DYNAMIC_CODES, tables[0], tables[1]);
  451.                     break;
  452.                 default:
  453.                     throw new IllegalStateException("Unsupported compression: " + mode);
  454.                 }
  455.             } else {
  456.                 final int r = state.read(b, off, len);
  457.                 if (r != 0) {
  458.                     return r;
  459.                 }
  460.             }
  461.         }
  462.         return -1;
  463.     }

  464.     /**
  465.      * @since 1.17
  466.      */
  467.     long getBytesRead() {
  468.         return reader.getBytesRead();
  469.     }

  470.     private long readBits(final int numBits) throws IOException {
  471.         return readBits(reader, numBits);
  472.     }

  473.     private int[][] readDynamicTables() throws IOException {
  474.         final int[][] result = new int[2][];
  475.         final int literals = (int) (readBits(5) + 257);
  476.         result[0] = new int[literals];

  477.         final int distances = (int) (readBits(5) + 1);
  478.         result[1] = new int[distances];

  479.         populateDynamicTables(reader, result[0], result[1]);
  480.         return result;
  481.     }

  482.     private void switchToUncompressedState() throws IOException {
  483.         reader.alignWithByteBoundary();
  484.         final long bLen = readBits(16);
  485.         final long bNLen = readBits(16);
  486.         if (((bLen ^ 0xFFFF) & 0xFFFF) != bNLen) {
  487.             // noinspection DuplicateStringLiteralInspection
  488.             throw new IllegalStateException("Illegal LEN / NLEN values");
  489.         }
  490.         state = new UncompressedState(bLen);
  491.     }
  492. }