OpenSslGaloisCounterMode.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  18. package org.apache.commons.crypto.cipher;

  19. import java.io.ByteArrayOutputStream;
  20. import java.nio.ByteBuffer;
  21. import java.nio.ByteOrder;
  22. import java.security.InvalidAlgorithmParameterException;
  23. import java.security.spec.AlgorithmParameterSpec;

  24. import javax.crypto.AEADBadTagException;
  25. import javax.crypto.BadPaddingException;
  26. import javax.crypto.IllegalBlockSizeException;
  27. import javax.crypto.ShortBufferException;
  28. import javax.crypto.spec.GCMParameterSpec;

  29. /**
  30.  * This class do the real work(Encryption/Decryption/Authentication) for the authenticated mode: GCM.
  31.  *
  32.  * It calls the OpenSSL API to implement the JCE-like behavior
  33.  *
  34.  * @since 1.1
  35.  */
  36. final class OpenSslGaloisCounterMode extends AbstractOpenSslFeedbackCipher {

  37.     static final int DEFAULT_TAG_LEN = 16;
  38.     // buffer for AAD data; if consumed, set as null
  39.     private ByteArrayOutputStream aadBuffer = new ByteArrayOutputStream();

  40.     private int tagBitLen = -1;

  41.     // buffer for storing input in decryption, not used for encryption
  42.     private ByteArrayOutputStream inBuffer;

  43.     public OpenSslGaloisCounterMode(final long context, final int algorithmMode, final int padding) {
  44.         super(context, algorithmMode, padding);
  45.     }

  46.     @Override
  47.     public void clean() {
  48.         super.clean();
  49.         aadBuffer = null;
  50.     }

  51.     @Override
  52.     public int doFinal(final byte[] input, final int inputOffset, final int inputLen, final byte[] output, final int outputOffset)
  53.             throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
  54.         checkState();

  55.         processAAD();

  56.         final int outputLength = output.length;
  57.         int len;
  58.         if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
  59.             // if GCM-DECRYPT, we have to handle the buffered input
  60.             // and the retrieve the trailing tag from input
  61.             int inputOffsetFinal = inputOffset;
  62.             int inputLenFinal = inputLen;
  63.             final byte[] inputFinal;
  64.             if (inBuffer != null && inBuffer.size() > 0) {
  65.                 inBuffer.write(input, inputOffset, inputLen);
  66.                 inputFinal = inBuffer.toByteArray();
  67.                 inputOffsetFinal = 0;
  68.                 inputLenFinal = inputFinal.length;
  69.                 inBuffer.reset();
  70.             } else {
  71.                 inputFinal = input;
  72.             }

  73.             if (inputFinal.length < getTagLen()) {
  74.                 throw new AEADBadTagException("Input too short - need tag");
  75.             }

  76.             final int inputDataLen = inputLenFinal - getTagLen();
  77.             len = OpenSslNative.updateByteArray(context, inputFinal, inputOffsetFinal,
  78.                     inputDataLen, output, outputOffset, outputLength - outputOffset);

  79.             // set tag to EVP_Cipher for integrity verification in doFinal
  80.             final ByteBuffer tag = ByteBuffer.allocate(getTagLen());
  81.             tag.put(input, input.length - getTagLen(), getTagLen());
  82.             tag.flip();
  83.             evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_SET_TAG.getValue(), getTagLen(), tag);
  84.         } else {
  85.             len = OpenSslNative.updateByteArray(context, input, inputOffset,
  86.                     inputLen, output, outputOffset, outputLength - outputOffset);
  87.         }

  88.         len += OpenSslNative.doFinalByteArray(context, output, outputOffset + len,
  89.                 outputLength - outputOffset - len);

  90.         // Keep the similar behavior as JCE, append the tag to end of output
  91.         if (this.cipherMode == OpenSsl.ENCRYPT_MODE) {
  92.             final ByteBuffer tag;
  93.             tag = ByteBuffer.allocate(getTagLen());
  94.             evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_GET_TAG.getValue(), getTagLen(), tag);
  95.             tag.get(output, outputLength - getTagLen(), getTagLen());
  96.             len += getTagLen();
  97.         }

  98.         return len;
  99.     }

  100.     @Override
  101.     public int doFinal(final ByteBuffer input, final ByteBuffer output)
  102.             throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
  103.         checkState();

  104.         processAAD();

  105.         int totalLen = 0;
  106.         int len;
  107.         if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
  108.             final ByteBuffer tag = ByteBuffer.allocate(getTagLen());

  109.             // if GCM-DECRYPT, we have to handle the buffered input
  110.             // and the retrieve the trailing tag from input
  111.             if (inBuffer != null && inBuffer.size() > 0) {
  112.                 final byte[] inputBytes = new byte[input.remaining()];
  113.                 input.get(inputBytes, 0, inputBytes.length);
  114.                 inBuffer.write(inputBytes, 0, inputBytes.length);
  115.                 final byte[] inputFinal = inBuffer.toByteArray();
  116.                 inBuffer.reset();

  117.                 if (inputFinal.length < getTagLen()) {
  118.                     throw new AEADBadTagException("Input too short - need tag");
  119.                 }

  120.                 len = OpenSslNative.updateByteArrayByteBuffer(context, inputFinal, 0,
  121.                         inputFinal.length - getTagLen(),
  122.                         output, output.position(), output.remaining());

  123.                 // retrieve tag
  124.                 tag.put(inputFinal, inputFinal.length - getTagLen(), getTagLen());

  125.             } else {
  126.                 // if no buffered input, just use the input directly
  127.                 if (input.remaining() < getTagLen()) {
  128.                     throw new AEADBadTagException("Input too short - need tag");
  129.                 }

  130.                 len = OpenSslNative.update(context, input, input.position(),
  131.                         input.remaining() - getTagLen(), output, output.position(),
  132.                         output.remaining());

  133.                 input.position(input.position() + len);

  134.                 // retrieve tag
  135.                 tag.put(input);
  136.             }
  137.             tag.flip();

  138.             // set tag to EVP_Cipher for integrity verification in doFinal
  139.             evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_SET_TAG.getValue(),
  140.                     getTagLen(), tag);
  141.         } else {
  142.             len = OpenSslNative.update(context, input, input.position(),
  143.                     input.remaining(), output, output.position(),
  144.                     output.remaining());
  145.             input.position(input.limit());
  146.         }

  147.         totalLen += len;
  148.         output.position(output.position() + len);

  149.         len = OpenSslNative.doFinal(context, output, output.position(),
  150.                 output.remaining());
  151.         output.position(output.position() + len);
  152.         totalLen += len;

  153.         // Keep the similar behavior as JCE, append the tag to end of output
  154.         if (this.cipherMode == OpenSsl.ENCRYPT_MODE) {
  155.             final ByteBuffer tag;
  156.             tag = ByteBuffer.allocate(getTagLen());
  157.             evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_GET_TAG.getValue(), getTagLen(), tag);
  158.             output.put(tag);
  159.             totalLen += getTagLen();
  160.         }

  161.         return totalLen;
  162.     }

  163.     /**
  164.      * Wraps of OpenSslNative.ctrl(long context, int type, int arg, byte[] data)
  165.      * Since native interface EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr) is generic,
  166.      * it may set/get any native char or long type to the data buffer(ptr).
  167.      * Here we use ByteBuffer and set nativeOrder to handle the endianness.
  168.      *
  169.      * @param context The cipher context address
  170.      * @param type CtrlValues
  171.      * @param arg argument like a tag length
  172.      * @param data byte buffer or null
  173.      * @return return 0 if there is any error, else return 1.
  174.      */
  175.     private int evpCipherCtxCtrl(final long context, final int type, final int arg, final ByteBuffer data) {
  176.         checkState();
  177.         try {
  178.             if (data != null) {
  179.                 data.order(ByteOrder.nativeOrder());
  180.                 return OpenSslNative.ctrl(context, type, arg, data.array());
  181.             }
  182.             return OpenSslNative.ctrl(context, type, arg, null);
  183.         } catch (final Exception e) {
  184.             System.out.println(e.getMessage());
  185.             return 0;
  186.         }
  187.     }

  188.     private int getTagLen() {
  189.         return tagBitLen < 0 ? DEFAULT_TAG_LEN : tagBitLen >> 3;
  190.     }

  191.     @Override
  192.     public void init(final int mode, final byte[] key, final AlgorithmParameterSpec params)
  193.             throws InvalidAlgorithmParameterException {

  194.         if (aadBuffer == null) {
  195.             aadBuffer = new ByteArrayOutputStream();
  196.         } else {
  197.             aadBuffer.reset();
  198.         }

  199.         this.cipherMode = mode;
  200.         final byte[] iv;
  201.         if (!(params instanceof GCMParameterSpec)) {
  202.             // other AlgorithmParameterSpec is not supported now.
  203.             throw new InvalidAlgorithmParameterException("Illegal parameters");
  204.         }
  205.         final GCMParameterSpec gcmParam = (GCMParameterSpec) params;
  206.         iv = gcmParam.getIV();
  207.         this.tagBitLen = gcmParam.getTLen();

  208.         if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
  209.             inBuffer = new ByteArrayOutputStream();
  210.         }

  211.         context = OpenSslNative.init(context, mode, algorithmMode, padding, key, iv);
  212.     }

  213.     private void processAAD() {
  214.         if (aadBuffer != null && aadBuffer.size() > 0) {
  215.             OpenSslNative.updateByteArray(context, aadBuffer.toByteArray(), 0, aadBuffer.size(), null, 0, 0);
  216.             aadBuffer = null;
  217.         }
  218.     }

  219.     @Override
  220.     public int update(final byte[] input, final int inputOffset, final int inputLen, final byte[] output, final int outputOffset)
  221.             throws ShortBufferException {
  222.         checkState();

  223.         processAAD();

  224.         if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
  225.             // store internally until doFinal(decrypt) is called because
  226.             // spec mentioned that only return recovered data after tag
  227.             // is successfully verified
  228.             inBuffer.write(input, inputOffset, inputLen);
  229.             return 0;
  230.         }
  231.         return OpenSslNative.updateByteArray(context, input, inputOffset,
  232.                 inputLen, output, outputOffset, output.length - outputOffset);
  233.     }

  234.     @Override
  235.     public int update(final ByteBuffer input, final ByteBuffer output) throws ShortBufferException {
  236.         checkState();

  237.         processAAD();

  238.         final int len;
  239.         if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
  240.             // store internally until doFinal(decrypt) is called because
  241.             // spec mentioned that only return recovered data after tag
  242.             // is successfully verified
  243.             final int inputLen = input.remaining();
  244.             final byte[] inputBuf = new byte[inputLen];
  245.             input.get(inputBuf, 0, inputLen);
  246.             inBuffer.write(inputBuf, 0, inputLen);
  247.             return 0;
  248.         }
  249.         len = OpenSslNative.update(context, input, input.position(),
  250.                 input.remaining(), output, output.position(),
  251.                 output.remaining());
  252.         input.position(input.limit());
  253.         output.position(output.position() + len);

  254.         return len;
  255.     }

  256.     @Override
  257.     public void updateAAD(final byte[] aad) {
  258.         // must be called after initialized.
  259.         if (aadBuffer == null) {
  260.             // update has already been called
  261.             throw new IllegalStateException("Update has been called; no more AAD data");
  262.         }
  263.         aadBuffer.write(aad, 0, aad.length);
  264.     }
  265. }