View Javadoc
1   /*
2    *  Licensed under the Apache License, Version 2.0 (the "License");
3    *  you may not use this file except in compliance with the License.
4    *  You may obtain a copy of the License at
5    * 
6    *       http://www.apache.org/licenses/LICENSE-2.0
7    * 
8    *  Unless required by applicable law or agreed to in writing, software
9    *  distributed under the License is distributed on an "AS IS" BASIS,
10   *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11   *  See the License for the specific language governing permissions and
12   *  limitations under the License.
14   */
15  
16  package org.apache.commons.imaging.formats.jpeg.decoder;
17  
18  import java.awt.image.BufferedImage;
19  import java.awt.image.ColorModel;
20  import java.awt.image.DataBuffer;
21  import java.awt.image.DirectColorModel;
22  import java.awt.image.Raster;
23  import java.awt.image.WritableRaster;
24  import java.io.ByteArrayInputStream;
25  import java.io.IOException;
26  import java.util.Arrays;
27  import java.util.Properties;
28  
29  import org.apache.commons.imaging.ImageReadException;
30  import org.apache.commons.imaging.common.BinaryFileParser;
31  import org.apache.commons.imaging.common.bytesource.ByteSource;
32  import org.apache.commons.imaging.formats.jpeg.JpegConstants;
33  import org.apache.commons.imaging.formats.jpeg.JpegUtils;
34  import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment;
35  import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment;
36  import org.apache.commons.imaging.formats.jpeg.segments.SofnSegment;
37  import org.apache.commons.imaging.formats.jpeg.segments.SosSegment;
38  
39  import static org.apache.commons.imaging.common.BinaryFunctions.*;
40  
41  public class JpegDecoder extends BinaryFileParser implements JpegUtils.Visitor {
42      /*
43       * JPEG is an advanced image format that takes significant computation to
44       * decode. Keep decoding fast: - Don't allocate memory inside loops,
45       * allocate it once and reuse. - Minimize calculations per pixel and per
46       * block (using lookup tables for YCbCr->RGB conversion doubled
47       * performance). - Math.round() is slow, use (int)(x+0.5f) instead for
48       * positive numbers.
49       */
50  
51      private final DqtSegment.QuantizationTable[] quantizationTables = new DqtSegment.QuantizationTable[4];
52      private final DhtSegment.HuffmanTable[] huffmanDCTables = new DhtSegment.HuffmanTable[4];
53      private final DhtSegment.HuffmanTable[] huffmanACTables = new DhtSegment.HuffmanTable[4];
54      private SofnSegment sofnSegment;
55      private SosSegment sosSegment;
56      private final float[][] scaledQuantizationTables = new float[4][];
57      private BufferedImage image;
58      private ImageReadException imageReadException;
59      private IOException ioException;
60      private final int[] zz = new int[64];
61      private final int[] blockInt = new int[64];
62      private final float[] block = new float[64];
63  
64      @Override
65      public boolean beginSOS() {
66          return true;
67      }
68  
69      @Override
70      public void visitSOS(final int marker, final byte[] markerBytes, final byte[] imageData) {
71          final ByteArrayInputStream is = new ByteArrayInputStream(imageData);
72          try {
73              final int segmentLength = read2Bytes("segmentLength", is, "Not a Valid JPEG File", getByteOrder());
74              final byte[] sosSegmentBytes = readBytes("SosSegment",
75                      is, segmentLength - 2, "Not a Valid JPEG File");
76              sosSegment = new SosSegment(marker, sosSegmentBytes);
77  
78              int hMax = 0;
79              int vMax = 0;
80              for (int i = 0; i < sofnSegment.numberOfComponents; i++) {
81                  hMax = Math.max(hMax,
82                          sofnSegment.getComponents(i).horizontalSamplingFactor);
83                  vMax = Math.max(vMax,
84                          sofnSegment.getComponents(i).verticalSamplingFactor);
85              }
86              final int hSize = 8 * hMax;
87              final int vSize = 8 * vMax;
88  
89              final JpegInputStream bitInputStream = new JpegInputStream(is);
90              final int xMCUs = (sofnSegment.width + hSize - 1) / hSize;
91              final int yMCUs = (sofnSegment.height + vSize - 1) / vSize;
92              final Block[] mcu = allocateMCUMemory();
93              final Block[] scaledMCU = new Block[mcu.length];
94              for (int i = 0; i < scaledMCU.length; i++) {
95                  scaledMCU[i] = new Block(hSize, vSize);
96              }
97              final int[] preds = new int[sofnSegment.numberOfComponents];
98              ColorModel colorModel;
99              WritableRaster raster;
100             if (sofnSegment.numberOfComponents == 3) {
101                 colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
102                         0x000000ff);
103                 raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
104                         sofnSegment.width, sofnSegment.height, new int[] {
105                                 0x00ff0000, 0x0000ff00, 0x000000ff }, null);
106             } else if (sofnSegment.numberOfComponents == 1) {
107                 colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
108                         0x000000ff);
109                 raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
110                         sofnSegment.width, sofnSegment.height, new int[] {
111                                 0x00ff0000, 0x0000ff00, 0x000000ff }, null);
112                 // FIXME: why do images come out too bright with CS_GRAY?
113                 // colorModel = new ComponentColorModel(
114                 // ColorSpace.getInstance(ColorSpace.CS_GRAY), false, true,
115                 // Transparency.OPAQUE, DataBuffer.TYPE_BYTE);
116                 // raster = colorModel.createCompatibleWritableRaster(
117                 // sofnSegment.width, sofnSegment.height);
118             } else {
119                 throw new ImageReadException(sofnSegment.numberOfComponents
120                         + " components are invalid or unsupported");
121             }
122             final DataBuffer dataBuffer = raster.getDataBuffer();
123 
124             for (int y1 = 0; y1 < vSize * yMCUs; y1 += vSize) {
125                 for (int x1 = 0; x1 < hSize * xMCUs; x1 += hSize) {
126                     readMCU(bitInputStream, preds, mcu);
127                     rescaleMCU(mcu, hSize, vSize, scaledMCU);
128                     int srcRowOffset = 0;
129                     int dstRowOffset = y1 * sofnSegment.width + x1;
130                     for (int y2 = 0; y2 < vSize && y1 + y2 < sofnSegment.height; y2++) {
131                         for (int x2 = 0; x2 < hSize
132                                 && x1 + x2 < sofnSegment.width; x2++) {
133                             if (scaledMCU.length == 3) {
134                                 final int Y = scaledMCU[0].samples[srcRowOffset + x2];
135                                 final int Cb = scaledMCU[1].samples[srcRowOffset + x2];
136                                 final int Cr = scaledMCU[2].samples[srcRowOffset + x2];
137                                 final int rgb = YCbCrConverter.convertYCbCrToRGB(Y,
138                                         Cb, Cr);
139                                 dataBuffer.setElem(dstRowOffset + x2, rgb);
140                             } else if (mcu.length == 1) {
141                                 final int Y = scaledMCU[0].samples[srcRowOffset + x2];
142                                 dataBuffer.setElem(dstRowOffset + x2, (Y << 16)
143                                         | (Y << 8) | Y);
144                             } else {
145                                 throw new ImageReadException(
146                                         "Unsupported JPEG with " + mcu.length
147                                                 + " components");
148                             }
149                         }
150                         srcRowOffset += hSize;
151                         dstRowOffset += sofnSegment.width;
152                     }
153                 }
154             }
155             image = new BufferedImage(colorModel, raster,
156                     colorModel.isAlphaPremultiplied(), new Properties());
157             // byte[] remainder = super.getStreamBytes(is);
158             // for (int i = 0; i < remainder.length; i++)
159             // {
160             // System.out.println("" + i + " = " +
161             // Integer.toHexString(remainder[i]));
162             // }
163         } catch (final ImageReadException imageReadEx) {
164             imageReadException = imageReadEx;
165         } catch (final IOException ioEx) {
166             ioException = ioEx;
167         } catch (final RuntimeException ex) {
168             // Corrupt images can throw NPE and IOOBE
169             imageReadException = new ImageReadException("Error parsing JPEG",
170                     ex);
171         }
172     }
173 
174     @Override
175     public boolean visitSegment(final int marker, final byte[] markerBytes,
176             final int segmentLength, final byte[] segmentLengthBytes, final byte[] segmentData)
177             throws ImageReadException, IOException {
178         final int[] sofnSegments = {
179                 JpegConstants.SOF0_MARKER,
180                 JpegConstants.SOF1_MARKER,
181                 JpegConstants.SOF2_MARKER,
182                 JpegConstants.SOF3_MARKER,
183                 JpegConstants.SOF5_MARKER,
184                 JpegConstants.SOF6_MARKER,
185                 JpegConstants.SOF7_MARKER,
186                 JpegConstants.SOF9_MARKER,
187                 JpegConstants.SOF10_MARKER,
188                 JpegConstants.SOF11_MARKER,
189                 JpegConstants.SOF13_MARKER,
190                 JpegConstants.SOF14_MARKER,
191                 JpegConstants.SOF15_MARKER,
192         };
193 
194         if (Arrays.binarySearch(sofnSegments, marker) >= 0) {
195             if (marker != JpegConstants.SOF0_MARKER) {
196                 throw new ImageReadException("Only sequential, baseline JPEGs "
197                         + "are supported at the moment");
198             }
199             sofnSegment = new SofnSegment(marker, segmentData);
200         } else if (marker == JpegConstants.DQT_MARKER) {
201             final DqtSegment dqtSegment = new DqtSegment(marker, segmentData);
202             for (int i = 0; i < dqtSegment.quantizationTables.size(); i++) {
203                 final DqtSegment.QuantizationTable table = dqtSegment.quantizationTables.get(i);
204                 if (0 > table.destinationIdentifier
205                         || table.destinationIdentifier >= quantizationTables.length) {
206                     throw new ImageReadException(
207                             "Invalid quantization table identifier "
208                                     + table.destinationIdentifier);
209                 }
210                 quantizationTables[table.destinationIdentifier] = table;
211                 final int[] quantizationMatrixInt = new int[64];
212                 ZigZag.zigZagToBlock(table.getElements(), quantizationMatrixInt);
213                 final float[] quantizationMatrixFloat = new float[64];
214                 for (int j = 0; j < 64; j++) {
215                     quantizationMatrixFloat[j] = quantizationMatrixInt[j];
216                 }
217                 Dct.scaleDequantizationMatrix(quantizationMatrixFloat);
218                 scaledQuantizationTables[table.destinationIdentifier] = quantizationMatrixFloat;
219             }
220         } else if (marker == JpegConstants.DHT_MARKER) {
221             final DhtSegment dhtSegment = new DhtSegment(marker, segmentData);
222             for (int i = 0; i < dhtSegment.huffmanTables.size(); i++) {
223                 final DhtSegment.HuffmanTable table = dhtSegment.huffmanTables.get(i);
224                 DhtSegment.HuffmanTable[] tables;
225                 if (table.tableClass == 0) {
226                     tables = huffmanDCTables;
227                 } else if (table.tableClass == 1) {
228                     tables = huffmanACTables;
229                 } else {
230                     throw new ImageReadException("Invalid huffman table class "
231                             + table.tableClass);
232                 }
233                 if (0 > table.destinationIdentifier
234                         || table.destinationIdentifier >= tables.length) {
235                     throw new ImageReadException(
236                             "Invalid huffman table identifier "
237                                     + table.destinationIdentifier);
238                 }
239                 tables[table.destinationIdentifier] = table;
240             }
241         }
242         return true;
243     }
244 
245     private void rescaleMCU(final Block[] dataUnits, final int hSize, final int vSize, final Block[] ret) {
246         for (int i = 0; i < dataUnits.length; i++) {
247             final Block dataUnit = dataUnits[i];
248             if (dataUnit.width == hSize && dataUnit.height == vSize) {
249                 System.arraycopy(dataUnit.samples, 0, ret[i].samples, 0, hSize
250                         * vSize);
251             } else {
252                 final int hScale = hSize / dataUnit.width;
253                 final int vScale = vSize / dataUnit.height;
254                 if (hScale == 2 && vScale == 2) {
255                     int srcRowOffset = 0;
256                     int dstRowOffset = 0;
257                     for (int y = 0; y < dataUnit.height; y++) {
258                         for (int x = 0; x < hSize; x++) {
259                             final int sample = dataUnit.samples[srcRowOffset + (x >> 1)];
260                             ret[i].samples[dstRowOffset + x] = sample;
261                             ret[i].samples[dstRowOffset + hSize + x] = sample;
262                         }
263                         srcRowOffset += dataUnit.width;
264                         dstRowOffset += 2 * hSize;
265                     }
266                 } else {
267                     // FIXME: optimize
268                     int dstRowOffset = 0;
269                     for (int y = 0; y < vSize; y++) {
270                         for (int x = 0; x < hSize; x++) {
271                             ret[i].samples[dstRowOffset + x] = dataUnit.samples[(y / vScale)
272                                     * dataUnit.width + (x / hScale)];
273                         }
274                         dstRowOffset += hSize;
275                     }
276                 }
277             }
278         }
279     }
280 
281     private Block[] allocateMCUMemory() throws ImageReadException {
282         final Block[] mcu = new Block[sosSegment.numberOfComponents];
283         for (int i = 0; i < sosSegment.numberOfComponents; i++) {
284             final SosSegment.Component scanComponent = sosSegment.getComponents(i);
285             SofnSegment.Component frameComponent = null;
286             for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
287                 if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
288                     frameComponent = sofnSegment.getComponents(j);
289                     break;
290                 }
291             }
292             if (frameComponent == null) {
293                 throw new ImageReadException("Invalid component");
294             }
295             final Block fullBlock = new Block(
296                     8 * frameComponent.horizontalSamplingFactor,
297                     8 * frameComponent.verticalSamplingFactor);
298             mcu[i] = fullBlock;
299         }
300         return mcu;
301     }
302 
303     private void readMCU(final JpegInputStream is, final int[] preds, final Block[] mcu)
304             throws IOException, ImageReadException {
305         for (int i = 0; i < sosSegment.numberOfComponents; i++) {
306             final SosSegment.Component scanComponent = sosSegment.getComponents(i);
307             SofnSegment.Component frameComponent = null;
308             for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
309                 if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
310                     frameComponent = sofnSegment.getComponents(j);
311                     break;
312                 }
313             }
314             if (frameComponent == null) {
315                 throw new ImageReadException("Invalid component");
316             }
317             final Block fullBlock = mcu[i];
318             for (int y = 0; y < frameComponent.verticalSamplingFactor; y++) {
319                 for (int x = 0; x < frameComponent.horizontalSamplingFactor; x++) {
320                     Arrays.fill(zz, 0);
321                     // page 104 of T.81
322                     final int t = decode(
323                             is,
324                             huffmanDCTables[scanComponent.dcCodingTableSelector]);
325                     int diff = receive(t, is);
326                     diff = extend(diff, t);
327                     zz[0] = preds[i] + diff;
328                     preds[i] = zz[0];
329 
330                     // "Decode_AC_coefficients", figure F.13, page 106 of T.81
331                     int k = 1;
332                     while (true) {
333                         final int rs = decode(
334                                 is,
335                                 huffmanACTables[scanComponent.acCodingTableSelector]);
336                         final int ssss = rs & 0xf;
337                         final int rrrr = rs >> 4;
338                         final int r = rrrr;
339 
340                         if (ssss == 0) {
341                             if (r == 15) {
342                                 k += 16;
343                             } else {
344                                 break;
345                             }
346                         } else {
347                             k += r;
348 
349                             // "Decode_ZZ(k)", figure F.14, page 107 of T.81
350                             zz[k] = receive(ssss, is);
351                             zz[k] = extend(zz[k], ssss);
352 
353                             if (k == 63) {
354                                 break;
355                             } else {
356                                 k++;
357                             }
358                         }
359                     }
360 
361                     final int shift = (1 << (sofnSegment.precision - 1));
362                     final int max = (1 << sofnSegment.precision) - 1;
363 
364                     final float[] scaledQuantizationTable = scaledQuantizationTables[frameComponent.quantTabDestSelector];
365                     ZigZag.zigZagToBlock(zz, blockInt);
366                     for (int j = 0; j < 64; j++) {
367                         block[j] = blockInt[j] * scaledQuantizationTable[j];
368                     }
369                     Dct.inverseDCT8x8(block);
370 
371                     int dstRowOffset = 8 * y * 8
372                             * frameComponent.horizontalSamplingFactor + 8 * x;
373                     int srcNext = 0;
374                     for (int yy = 0; yy < 8; yy++) {
375                         for (int xx = 0; xx < 8; xx++) {
376                             float sample = block[srcNext++];
377                             sample += shift;
378                             int result;
379                             if (sample < 0) {
380                                 result = 0;
381                             } else if (sample > max) {
382                                 result = max;
383                             } else {
384                                 result = fastRound(sample);
385                             }
386                             fullBlock.samples[dstRowOffset + xx] = result;
387                         }
388                         dstRowOffset += 8 * frameComponent.horizontalSamplingFactor;
389                     }
390                 }
391             }
392         }
393     }
394 
395     private static int fastRound(final float x) {
396         return (int) (x + 0.5f);
397     }
398 
399     private int extend(int v, final int t) {
400         // "EXTEND", section F.2.2.1, figure F.12, page 105 of T.81
401         int vt = (1 << (t - 1));
402         while (v < vt) {
403             vt = (-1 << t) + 1;
404             v += vt;
405         }
406         return v;
407     }
408 
409     private int receive(final int ssss, final JpegInputStream is) throws IOException,
410             ImageReadException {
411         // "RECEIVE", section F.2.2.4, figure F.17, page 110 of T.81
412         int i = 0;
413         int v = 0;
414         while (i != ssss) {
415             i++;
416             v = (v << 1) + is.nextBit();
417         }
418         return v;
419     }
420 
421     private int decode(final JpegInputStream is, final DhtSegment.HuffmanTable huffmanTable)
422             throws IOException, ImageReadException {
423         // "DECODE", section F.2.2.3, figure F.16, page 109 of T.81
424         int i = 1;
425         int code = is.nextBit();
426         while (code > huffmanTable.getMaxCode(i)) {
427             i++;
428             code = (code << 1) | is.nextBit();
429         }
430         int j = huffmanTable.getValPtr(i);
431         j += code - huffmanTable.getMinCode(i);
432         return huffmanTable.getHuffVal(j);
433     }
434 
435     public BufferedImage decode(final ByteSource byteSource) throws IOException,
436             ImageReadException {
437         final JpegUtils jpegUtils = new JpegUtils();
438         jpegUtils.traverseJFIF(byteSource, this);
439         if (imageReadException != null) {
440             throw imageReadException;
441         }
442         if (ioException != null) {
443             throw ioException;
444         }
445         return image;
446     }
447 }