View Javadoc
1   /*
2    *  Licensed under the Apache License, Version 2.0 (the "License");
3    *  you may not use this file except in compliance with the License.
4    *  You may obtain a copy of the License at
5    *
6    *       http://www.apache.org/licenses/LICENSE-2.0
7    *
8    *  Unless required by applicable law or agreed to in writing, software
9    *  distributed under the License is distributed on an "AS IS" BASIS,
10   *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11   *  See the License for the specific language governing permissions and
12   *  limitations under the License.
13   *  under the License.
14   */
15  
16  package org.apache.commons.imaging.formats.jpeg.decoder;
17  
18  import static org.apache.commons.imaging.common.BinaryFunctions.read2Bytes;
19  import static org.apache.commons.imaging.common.BinaryFunctions.readBytes;
20  
21  import java.awt.image.BufferedImage;
22  import java.awt.image.ColorModel;
23  import java.awt.image.DataBuffer;
24  import java.awt.image.DirectColorModel;
25  import java.awt.image.Raster;
26  import java.awt.image.WritableRaster;
27  import java.io.ByteArrayInputStream;
28  import java.io.IOException;
29  import java.util.ArrayList;
30  import java.util.Arrays;
31  import java.util.List;
32  import java.util.Properties;
33  
34  import org.apache.commons.imaging.ImageReadException;
35  import org.apache.commons.imaging.color.ColorConversions;
36  import org.apache.commons.imaging.common.BinaryFileParser;
37  import org.apache.commons.imaging.common.bytesource.ByteSource;
38  import org.apache.commons.imaging.formats.jpeg.JpegConstants;
39  import org.apache.commons.imaging.formats.jpeg.JpegUtils;
40  import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment;
41  import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment.HuffmanTable;
42  import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment;
43  import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment.QuantizationTable;
44  import org.apache.commons.imaging.formats.jpeg.segments.SofnSegment;
45  import org.apache.commons.imaging.formats.jpeg.segments.SosSegment;
46  
47  public class JpegDecoder extends BinaryFileParser implements JpegUtils.Visitor {
    /*
     * JPEG is an advanced image format that takes significant computation to
     * decode. Keep decoding fast:
     * - Don't allocate memory inside loops; allocate it once and reuse it.
     * - Minimize calculations per pixel and per block (using lookup tables
     *   for YCbCr->RGB conversion doubled performance).
     * - Math.round() is slow; use (int) (x + 0.5f) instead for positive
     *   numbers.
     */
56  
57      private final DqtSegment.QuantizationTable[] quantizationTables = new DqtSegment.QuantizationTable[4];
58      private final DhtSegment.HuffmanTable[] huffmanDCTables = new DhtSegment.HuffmanTable[4];
59      private final DhtSegment.HuffmanTable[] huffmanACTables = new DhtSegment.HuffmanTable[4];
60      private SofnSegment sofnSegment;
61      private SosSegment sosSegment;
62      private final float[][] scaledQuantizationTables = new float[4][];
63      private BufferedImage image;
64      private ImageReadException imageReadException;
65      private IOException ioException;
66      private final int[] zz = new int[64];
67      private final int[] blockInt = new int[64];
68      private final float[] block = new float[64];
69  
70      @Override
71      public boolean beginSOS() {
72          return true;
73      }
74  
75      @Override
76      public void visitSOS(final int marker, final byte[] markerBytes, final byte[] imageData) {
77          final ByteArrayInputStream is = new ByteArrayInputStream(imageData);
78          try {
79              // read the scan header
80              final int segmentLength = read2Bytes("segmentLength", is,"Not a Valid JPEG File", getByteOrder());
81              final byte[] sosSegmentBytes = readBytes("SosSegment", is, segmentLength - 2, "Not a Valid JPEG File");
82              sosSegment = new SosSegment(marker, sosSegmentBytes);
83              // read the payload of the scan, this is the remainder of image data after the header
84              // the payload contains the entropy-encoded segments (or ECS) divided by RST markers
85              // or only one ECS if the entropy-encoded data is not divided by RST markers
86              // length of payload = length of image data - length of data already read
87              final int[] scanPayload = new int[imageData.length - segmentLength];
88              int payloadReadCount = 0;
89              while (payloadReadCount < scanPayload.length) {
90                  scanPayload[payloadReadCount] = is.read();
91                  payloadReadCount++;
92              }
93  
94              int hMax = 0;
95              int vMax = 0;
96              for (int i = 0; i < sofnSegment.numberOfComponents; i++) {
97                  hMax = Math.max(hMax,
98                          sofnSegment.getComponents(i).horizontalSamplingFactor);
99                  vMax = Math.max(vMax,
100                         sofnSegment.getComponents(i).verticalSamplingFactor);
101             }
102             final int hSize = 8 * hMax;
103             final int vSize = 8 * vMax;
104 
105             final int xMCUs = (sofnSegment.width + hSize - 1) / hSize;
106             final int yMCUs = (sofnSegment.height + vSize - 1) / vSize;
107             final Block[] mcu = allocateMCUMemory();
108             final Blocking/formats/jpeg/decoder/Block.html#Block">Block[] scaledMCU = new Block[mcu.length];
109             for (int i = 0; i < scaledMCU.length; i++) {
110                 scaledMCU[i] = new Block(hSize, vSize);
111             }
112             final int[] preds = new int[sofnSegment.numberOfComponents];
113             ColorModel colorModel;
114             WritableRaster raster;
115             switch (sofnSegment.numberOfComponents) {
116             case 4:
117                 colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00, 0x000000ff);
118                 final int[] bandMasks = new int[] { 0x00ff0000, 0x0000ff00, 0x000000ff };
119                 raster = Raster.createPackedRaster(DataBuffer.TYPE_INT, sofnSegment.width, sofnSegment.height, bandMasks, null);
120                 break;
121             case 3:
122                 colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
123                         0x000000ff);
124                 raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
125                         sofnSegment.width, sofnSegment.height, new int[] {
126                                 0x00ff0000, 0x0000ff00, 0x000000ff }, null);
127                 break;
128             case 1:
129                 colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
130                         0x000000ff);
131                 raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
132                         sofnSegment.width, sofnSegment.height, new int[] {
133                                 0x00ff0000, 0x0000ff00, 0x000000ff }, null);
134                 // FIXME: why do images come out too bright with CS_GRAY?
135                 // colorModel = new ComponentColorModel(
136                 // ColorSpace.getInstance(ColorSpace.CS_GRAY), false, true,
137                 // Transparency.OPAQUE, DataBuffer.TYPE_BYTE);
138                 // raster = colorModel.createCompatibleWritableRaster(
139                 // sofnSegment.width, sofnSegment.height);
140                 break;
141             default:
142                 throw new ImageReadException(sofnSegment.numberOfComponents
143                         + " components are invalid or unsupported");
144             }
145             final DataBuffer dataBuffer = raster.getDataBuffer();
146 
147             final JpegInputStream[] bitInputStreams = splitByRstMarkers(scanPayload);
148             int bitInputStreamCount = 0;
149             JpegInputStream bitInputStream = bitInputStreams[0];
150 
151             for (int y1 = 0; y1 < vSize * yMCUs; y1 += vSize) {
152                 for (int x1 = 0; x1 < hSize * xMCUs; x1 += hSize) {
153                     // Provide the next interval if an interval is read until it's end
154                     // as long there are unread intervals available
155                     if (!bitInputStream.hasNext()) {
156                         bitInputStreamCount++;
157                         if (bitInputStreamCount < bitInputStreams.length) {
158                             bitInputStream = bitInputStreams[bitInputStreamCount];
159                         }
160                     }
161 
162                     readMCU(bitInputStream, preds, mcu);
163                     rescaleMCU(mcu, hSize, vSize, scaledMCU);
164                     int srcRowOffset = 0;
165                     int dstRowOffset = y1 * sofnSegment.width + x1;
166                     for (int y2 = 0; y2 < vSize && y1 + y2 < sofnSegment.height; y2++) {
167                         for (int x2 = 0; x2 < hSize
168                                 && x1 + x2 < sofnSegment.width; x2++) {
169                             if (scaledMCU.length == 4) {
170                                 final int C = scaledMCU[0].samples[srcRowOffset + x2];
171                                 final int M = scaledMCU[1].samples[srcRowOffset + x2];
172                                 final int Y = scaledMCU[2].samples[srcRowOffset + x2];
173                                 final int K = scaledMCU[3].samples[srcRowOffset + x2];
174                                 final int rgb = ColorConversions.convertCMYKtoRGB(C, M, Y, K);
175                                 dataBuffer.setElem(dstRowOffset + x2, rgb);
176                             } else if (scaledMCU.length == 3) {
177                                 final int Y = scaledMCU[0].samples[srcRowOffset + x2];
178                                 final int Cb = scaledMCU[1].samples[srcRowOffset + x2];
179                                 final int Cr = scaledMCU[2].samples[srcRowOffset + x2];
180                                 final int rgb = YCbCrConverter.convertYCbCrToRGB(Y,
181                                         Cb, Cr);
182                                 dataBuffer.setElem(dstRowOffset + x2, rgb);
183                             } else if (mcu.length == 1) {
184                                 final int Y = scaledMCU[0].samples[srcRowOffset + x2];
185                                 dataBuffer.setElem(dstRowOffset + x2, (Y << 16)
186                                         | (Y << 8) | Y);
187                             } else {
188                                 throw new ImageReadException(
189                                         "Unsupported JPEG with " + mcu.length
190                                                 + " components");
191                             }
192                         }
193                         srcRowOffset += hSize;
194                         dstRowOffset += sofnSegment.width;
195                     }
196                 }
197             }
198             image = new BufferedImage(colorModel, raster,
199                     colorModel.isAlphaPremultiplied(), new Properties());
200             // byte[] remainder = super.getStreamBytes(is);
201             // for (int i = 0; i < remainder.length; i++)
202             // {
203             // System.out.println("" + i + " = " +
204             // Integer.toHexString(remainder[i]));
205             // }
206         } catch (final ImageReadException imageReadEx) {
207             imageReadException = imageReadEx;
208         } catch (final IOException ioEx) {
209             ioException = ioEx;
210         } catch (final RuntimeException ex) {
211             // Corrupt images can throw NPE and IOOBE
212             imageReadException = new ImageReadException("Error parsing JPEG",ex);
213         }
214     }
215 
216     @Override
217     public boolean visitSegment(final int marker, final byte[] markerBytes,
218             final int segmentLength, final byte[] segmentLengthBytes, final byte[] segmentData)
219             throws ImageReadException, IOException {
220         final int[] sofnSegments = {
221                 JpegConstants.SOF0_MARKER,
222                 JpegConstants.SOF1_MARKER,
223                 JpegConstants.SOF2_MARKER,
224                 JpegConstants.SOF3_MARKER,
225                 JpegConstants.SOF5_MARKER,
226                 JpegConstants.SOF6_MARKER,
227                 JpegConstants.SOF7_MARKER,
228                 JpegConstants.SOF9_MARKER,
229                 JpegConstants.SOF10_MARKER,
230                 JpegConstants.SOF11_MARKER,
231                 JpegConstants.SOF13_MARKER,
232                 JpegConstants.SOF14_MARKER,
233                 JpegConstants.SOF15_MARKER,
234         };
235 
236         if (Arrays.binarySearch(sofnSegments, marker) >= 0) {
237             if (marker != JpegConstants.SOF0_MARKER) {
238                 throw new ImageReadException("Only sequential, baseline JPEGs "
239                         + "are supported at the moment");
240             }
241             sofnSegment = new SofnSegment(marker, segmentData);
242         } else if (marker == JpegConstants.DQT_MARKER) {
243             final DqtSegmentformats/jpeg/segments/DqtSegment.html#DqtSegment">DqtSegment dqtSegment = new DqtSegment(marker, segmentData);
244             for (final QuantizationTable table : dqtSegment.quantizationTables) {
245                 if (0 > table.destinationIdentifier
246                         || table.destinationIdentifier >= quantizationTables.length) {
247                     throw new ImageReadException(
248                             "Invalid quantization table identifier "
249                                     + table.destinationIdentifier);
250                 }
251                 quantizationTables[table.destinationIdentifier] = table;
252                 final int[] quantizationMatrixInt = new int[64];
253                 ZigZag.zigZagToBlock(table.getElements(), quantizationMatrixInt);
254                 final float[] quantizationMatrixFloat = new float[64];
255                 for (int j = 0; j < 64; j++) {
256                     quantizationMatrixFloat[j] = quantizationMatrixInt[j];
257                 }
258                 Dct.scaleDequantizationMatrix(quantizationMatrixFloat);
259                 scaledQuantizationTables[table.destinationIdentifier] = quantizationMatrixFloat;
260             }
261         } else if (marker == JpegConstants.DHT_MARKER) {
262             final DhtSegmentformats/jpeg/segments/DhtSegment.html#DhtSegment">DhtSegment dhtSegment = new DhtSegment(marker, segmentData);
263             for (final HuffmanTable table : dhtSegment.huffmanTables) {
264                 DhtSegment.HuffmanTable[] tables;
265                 if (table.tableClass == 0) {
266                     tables = huffmanDCTables;
267                 } else if (table.tableClass == 1) {
268                     tables = huffmanACTables;
269                 } else {
270                     throw new ImageReadException("Invalid huffman table class "
271                             + table.tableClass);
272                 }
273                 if (0 > table.destinationIdentifier
274                         || table.destinationIdentifier >= tables.length) {
275                     throw new ImageReadException(
276                             "Invalid huffman table identifier "
277                                     + table.destinationIdentifier);
278                 }
279                 tables[table.destinationIdentifier] = table;
280             }
281         }
282         return true;
283     }
284 
285     private void rescaleMCU(final Block[] dataUnits, final int hSize, Block">final int vSize, final Block[] ret) {
286         for (int i = 0; i < dataUnits.length; i++) {
287             final Block dataUnit = dataUnits[i];
288             if (dataUnit.width == hSize && dataUnit.height == vSize) {
289                 System.arraycopy(dataUnit.samples, 0, ret[i].samples, 0, hSize
290                         * vSize);
291             } else {
292                 final int hScale = hSize / dataUnit.width;
293                 final int vScale = vSize / dataUnit.height;
294                 if (hScale == 2 && vScale == 2) {
295                     int srcRowOffset = 0;
296                     int dstRowOffset = 0;
297                     for (int y = 0; y < dataUnit.height; y++) {
298                         for (int x = 0; x < hSize; x++) {
299                             final int sample = dataUnit.samples[srcRowOffset + (x >> 1)];
300                             ret[i].samples[dstRowOffset + x] = sample;
301                             ret[i].samples[dstRowOffset + hSize + x] = sample;
302                         }
303                         srcRowOffset += dataUnit.width;
304                         dstRowOffset += 2 * hSize;
305                     }
306                 } else {
307                     // FIXME: optimize
308                     int dstRowOffset = 0;
309                     for (int y = 0; y < vSize; y++) {
310                         for (int x = 0; x < hSize; x++) {
311                             ret[i].samples[dstRowOffset + x] = dataUnit.samples[(y / vScale)
312                                     * dataUnit.width + (x / hScale)];
313                         }
314                         dstRowOffset += hSize;
315                     }
316                 }
317             }
318         }
319     }
320 
321     private Block[] allocateMCUMemory() throws ImageReadException {
322         final Blocks/imaging/formats/jpeg/decoder/Block.html#Block">Block[] mcu = new Block[sosSegment.numberOfComponents];
323         for (int i = 0; i < sosSegment.numberOfComponents; i++) {
324             final SosSegment.Component scanComponent = sosSegment.getComponents(i);
325             SofnSegment.Component frameComponent = null;
326             for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
327                 if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
328                     frameComponent = sofnSegment.getComponents(j);
329                     break;
330                 }
331             }
332             if (frameComponent == null) {
333                 throw new ImageReadException("Invalid component");
334             }
335             final Blockaging/formats/jpeg/decoder/Block.html#Block">Block fullBlock = new Block(
336                     8 * frameComponent.horizontalSamplingFactor,
337                     8 * frameComponent.verticalSamplingFactor);
338             mcu[i] = fullBlock;
339         }
340         return mcu;
341     }
342 
343     private void readMCU(final JpegInputStream is, final int[] preds, final Block[] mcu)
344             throws ImageReadException {
345         for (int i = 0; i < sosSegment.numberOfComponents; i++) {
346             final SosSegment.Component scanComponent = sosSegment.getComponents(i);
347             SofnSegment.Component frameComponent = null;
348             for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
349                 if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
350                     frameComponent = sofnSegment.getComponents(j);
351                     break;
352                 }
353             }
354             if (frameComponent == null) {
355                 throw new ImageReadException("Invalid component");
356             }
357             final Block fullBlock = mcu[i];
358             for (int y = 0; y < frameComponent.verticalSamplingFactor; y++) {
359                 for (int x = 0; x < frameComponent.horizontalSamplingFactor; x++) {
360                     Arrays.fill(zz, 0);
361                     // page 104 of T.81
362                     final int t = decode(
363                             is,
364                             huffmanDCTables[scanComponent.dcCodingTableSelector]);
365                     int diff = receive(t, is);
366                     diff = extend(diff, t);
367                     zz[0] = preds[i] + diff;
368                     preds[i] = zz[0];
369 
370                     // "Decode_AC_coefficients", figure F.13, page 106 of T.81
371                     int k = 1;
372                     while (true) {
373                         final int rs = decode(
374                                 is,
375                                 huffmanACTables[scanComponent.acCodingTableSelector]);
376                         final int ssss = rs & 0xf;
377                         final int rrrr = rs >> 4;
378                         final int r = rrrr;
379 
380                         if (ssss == 0) {
381                             if (r != 15) {
382                                 break;
383                             }
384                             k += 16;
385                         } else {
386                             k += r;
387 
388                             // "Decode_ZZ(k)", figure F.14, page 107 of T.81
389                             zz[k] = receive(ssss, is);
390                             zz[k] = extend(zz[k], ssss);
391 
392                             if (k == 63) {
393                                 break;
394                             }
395                             k++;
396                         }
397                     }
398 
399                     final int shift = (1 << (sofnSegment.precision - 1));
400                     final int max = (1 << sofnSegment.precision) - 1;
401 
402                     final float[] scaledQuantizationTable = scaledQuantizationTables[frameComponent.quantTabDestSelector];
403                     ZigZag.zigZagToBlock(zz, blockInt);
404                     for (int j = 0; j < 64; j++) {
405                         block[j] = blockInt[j] * scaledQuantizationTable[j];
406                     }
407                     Dct.inverseDCT8x8(block);
408 
409                     int dstRowOffset = 8 * y * 8
410                             * frameComponent.horizontalSamplingFactor + 8 * x;
411                     int srcNext = 0;
412                     for (int yy = 0; yy < 8; yy++) {
413                         for (int xx = 0; xx < 8; xx++) {
414                             float sample = block[srcNext++];
415                             sample += shift;
416                             int result;
417                             if (sample < 0) {
418                                 result = 0;
419                             } else if (sample > max) {
420                                 result = max;
421                             } else {
422                                 result = fastRound(sample);
423                             }
424                             fullBlock.samples[dstRowOffset + xx] = result;
425                         }
426                         dstRowOffset += 8 * frameComponent.horizontalSamplingFactor;
427                     }
428                 }
429             }
430         }
431     }
432 
433     /**
434      * Returns an array of JpegInputStream where each field contains the JpegInputStream
435      * for one interval.
436      * @param scanPayload array to read intervals from
437      * @return JpegInputStreams for all intervals, at least one stream is always provided
438      */
439     static JpegInputStream[] splitByRstMarkers(final int[] scanPayload) {
440         final List<Integer> intervalStarts = getIntervalStartPositions(scanPayload);
441         // get number of intervals in payload to init an array of appropriate length
442         final int intervalCount = intervalStarts.size();
443         final JpegInputStreamats/jpeg/decoder/JpegInputStream.html#JpegInputStream">JpegInputStream[] streams = new JpegInputStream[intervalCount];
444         for (int i = 0; i < intervalCount; i++) {
445             final int from = intervalStarts.get(i);
446             int to;
447             if (i < intervalCount - 1) {
448                 // because each restart marker needs two bytes the end of
449                 // this interval is two bytes before the next interval starts
450                 to = intervalStarts.get(i + 1) - 2;
451             } else { // the last interval ends with the array
452                 to = scanPayload.length;
453             }
454             final int[] interval = Arrays.copyOfRange(scanPayload, from, to);
455             streams[i] = new JpegInputStream(interval);
456         }
457         return streams;
458     }
459 
460     /**
461      * Returns the positions of where each interval in the provided array starts. The number
462      * of start positions is also the count of intervals while the number of restart markers
463      * found is equal to the number of start positions minus one (because restart markers
464      * are between intervals).
465      *
466      * @param scanPayload array to examine
467      * @return the start positions
468      */
469     static List<Integer> getIntervalStartPositions(final int[] scanPayload) {
470         final List<Integer> intervalStarts = new ArrayList<>();
471         intervalStarts.add(0);
472         boolean foundFF = false;
473         boolean foundD0toD7 = false;
474         int pos = 0;
475         while (pos < scanPayload.length) {
476             if (foundFF) {
477                 // found 0xFF D0 .. 0xFF D7 => RST marker
478                 if (scanPayload[pos] >= (0xff & JpegConstants.RST0_MARKER) &&
479                     scanPayload[pos] <= (0xff & JpegConstants.RST7_MARKER)) {
480                     foundD0toD7 = true;
481                 } else { // found 0xFF followed by something else => no RST marker
482                     foundFF = false;
483                 }
484             }
485 
486             if (scanPayload[pos] == 0xFF) {
487                 foundFF = true;
488             }
489 
490             // true if one of the RST markers was found
491             if (foundFF && foundD0toD7) {
492                 // we need to add the position after the current position because
493                 // we had already read 0xFF and are now at 0xDn
494                 intervalStarts.add(pos + 1);
495                 foundFF = foundD0toD7 = false;
496             }
497             pos++;
498         }
499         return intervalStarts;
500     }
501 
502     private static int fastRound(final float x) {
503         return (int) (x + 0.5f);
504     }
505 
506     private int extend(int v, final int t) {
507         // "EXTEND", section F.2.2.1, figure F.12, page 105 of T.81
508         int vt = (1 << (t - 1));
509         if (v < vt) {
510             vt = (-1 << t) + 1;
511             v += vt;
512         }
513         return v;
514     }
515 
516     private int receive(final int ssss, final JpegInputStream is) throws ImageReadException {
517         // "RECEIVE", section F.2.2.4, figure F.17, page 110 of T.81
518         int i = 0;
519         int v = 0;
520         while (i != ssss) {
521             i++;
522             v = (v << 1) + is.nextBit();
523         }
524         return v;
525     }
526 
527     private int decode(final JpegInputStream is, final DhtSegment.HuffmanTable huffmanTable)
528             throws ImageReadException {
529         // "DECODE", section F.2.2.3, figure F.16, page 109 of T.81
530         int i = 1;
531         int code = is.nextBit();
532         while (code > huffmanTable.getMaxCode(i)) {
533             i++;
534             code = (code << 1) | is.nextBit();
535         }
536         int j = huffmanTable.getValPtr(i);
537         j += code - huffmanTable.getMinCode(i);
538         return huffmanTable.getHuffVal(j);
539     }
540 
541     public BufferedImage decode(final ByteSource byteSource) throws IOException,
542             ImageReadException {
543         final JpegUtilsg/formats/jpeg/JpegUtils.html#JpegUtils">JpegUtils jpegUtils = new JpegUtils();
544         jpegUtils.traverseJFIF(byteSource, this);
545         if (imageReadException != null) {
546             throw imageReadException;
547         }
548         if (ioException != null) {
549             throw ioException;
550         }
551         return image;
552     }
553 }