OpenSslGaloisCounterMode.java

 /*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.crypto.cipher;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.InvalidAlgorithmParameterException;
import java.security.spec.AlgorithmParameterSpec;

import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.GCMParameterSpec;

/**
 * This class does the real work (encryption, decryption and authentication) for the
 * authenticated mode GCM.
 *
 * It calls the OpenSSL API to implement the JCE-like behavior.
 *
 * @since 1.1
 */
final class OpenSslGaloisCounterMode extends AbstractOpenSslFeedbackCipher {

    static final int DEFAULT_TAG_LEN = 16;

    // buffer for AAD data; set to null once consumed (or after clean())
    private ByteArrayOutputStream aadBuffer = new ByteArrayOutputStream();

    // tag length in bits taken from the GCMParameterSpec; -1 until init() has been called
    private int tagBitLen = -1;

    // buffer for storing input in decryption, not used for encryption
    private ByteArrayOutputStream inBuffer;

    public OpenSslGaloisCounterMode(final long context, final int algorithmMode, final int padding) {
        super(context, algorithmMode, padding);
    }

    /**
     * Delegates to the superclass clean-up and drops any pending AAD, so a further
     * {@link #updateAAD(byte[])} is rejected until {@link #init} is called again.
     */
    @Override
    public void clean() {
        super.clean();
        aadBuffer = null;
    }

    /**
     * Finishes a GCM encryption or decryption over a byte-array slice.
     *
     * <p>Encryption: the ciphertext of {@code input[inputOffset, inputOffset + inputLen)} is
     * written at {@code output[outputOffset]}, immediately followed by the authentication tag
     * (JCE-like layout).
     *
     * <p>Decryption: any input buffered by {@code update} is combined with this slice; the last
     * {@code getTagLen()} bytes of the combined input are the tag, which is handed to OpenSSL
     * before the native final step (presumably where a mismatching tag is rejected — that
     * check lives in {@code OpenSslNative}).
     * NOTE(review): the native update writes plaintext into {@code output} before the tag is
     * verified, so a caller must not use {@code output} if this method throws.
     *
     * @param input the input bytes
     * @param inputOffset offset of the first input byte
     * @param inputLen number of input bytes
     * @param output the output array
     * @param outputOffset where the first output byte is written
     * @return number of bytes written to {@code output} (including the tag when encrypting)
     * @throws ShortBufferException if, when encrypting, {@code output} cannot hold the ciphertext
     *         plus the tag (checked before any state is consumed), or if the native layer reports
     *         insufficient space
     * @throws AEADBadTagException if, when decrypting, the combined input is shorter than the tag
     * @throws IllegalBlockSizeException propagated from the native layer
     * @throws BadPaddingException propagated from the native layer
     */
    @Override
    public int doFinal(final byte[] input, final int inputOffset, final int inputLen, final byte[] output, final int outputOffset)
            throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
        checkState();

        final int tagLen = getTagLen();
        final int outputLength = output.length;
        // GCM encryption emits exactly inputLen bytes of ciphertext followed by the tag; reject a
        // short buffer before the AAD and the native context are consumed.
        if (this.cipherMode == OpenSsl.ENCRYPT_MODE && outputLength - outputOffset - inputLen < tagLen) {
            throw new ShortBufferException("Output buffer too short: need " + inputLen + " bytes of ciphertext plus "
                    + tagLen + " bytes of tag");
        }

        processAAD();

        int len;
        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            // if GCM-DECRYPT, we have to handle the buffered input
            // and then retrieve the trailing tag from the combined input
            final byte[] inputFinal;
            final int inputOffsetFinal;
            final int inputLenFinal;
            if (inBuffer != null && inBuffer.size() > 0) {
                inBuffer.write(input, inputOffset, inputLen);
                inputFinal = inBuffer.toByteArray();
                inputOffsetFinal = 0;
                inputLenFinal = inputFinal.length;
                inBuffer.reset();
            } else {
                inputFinal = input;
                inputOffsetFinal = inputOffset;
                inputLenFinal = inputLen;
            }

            // Compare the logical input length, not the length of the backing array.
            if (inputLenFinal < tagLen) {
                throw new AEADBadTagException("Input too short - need tag");
            }

            final int inputDataLen = inputLenFinal - tagLen;
            len = OpenSslNative.updateByteArray(context, inputFinal, inputOffsetFinal,
                    inputDataLen, output, outputOffset, outputLength - outputOffset);

            // The tag is the last tagLen bytes of the logical (possibly buffered) input;
            // set it on the EVP context for integrity verification in the final step.
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            tag.put(inputFinal, inputOffsetFinal + inputDataLen, tagLen);
            tag.flip();
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_SET_TAG.getValue(), tagLen, tag);
        } else {
            len = OpenSslNative.updateByteArray(context, input, inputOffset,
                    inputLen, output, outputOffset, outputLength - outputOffset);
        }

        len += OpenSslNative.doFinalByteArray(context, output, outputOffset + len,
                outputLength - outputOffset - len);

        // Keep the similar behavior as JCE: append the tag directly after the ciphertext.
        if (this.cipherMode == OpenSsl.ENCRYPT_MODE) {
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_GET_TAG.getValue(), tagLen, tag);
            tag.get(output, outputOffset + len, tagLen);
            len += tagLen;
        }

        return len;
    }

    /**
     * Finishes a GCM encryption or decryption over {@link ByteBuffer}s.
     *
     * <p>Semantics match {@link #doFinal(byte[], int, int, byte[], int)}: when encrypting the tag
     * is appended after the ciphertext; when decrypting the last {@code getTagLen()} bytes of the
     * (buffered + remaining) input are the tag. {@code input} is fully consumed and
     * {@code output}'s position is advanced past everything written.
     *
     * @param input the input buffer (from its position to its limit)
     * @param output the output buffer (written from its position)
     * @return number of bytes written to {@code output} (including the tag when encrypting)
     * @throws ShortBufferException if, when encrypting, {@code output} cannot hold the ciphertext
     *         plus the tag (checked before any state is consumed), or if the native layer reports
     *         insufficient space
     * @throws AEADBadTagException if, when decrypting, the combined input is shorter than the tag
     * @throws IllegalBlockSizeException propagated from the native layer
     * @throws BadPaddingException propagated from the native layer
     */
    @Override
    public int doFinal(final ByteBuffer input, final ByteBuffer output)
            throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
        checkState();

        final int tagLen = getTagLen();
        // Reject a short output buffer before the AAD and the native context are consumed;
        // otherwise output.put(tag) would fail only after the cipher was already finalized.
        if (this.cipherMode == OpenSsl.ENCRYPT_MODE && output.remaining() - input.remaining() < tagLen) {
            throw new ShortBufferException("Output buffer too short: need " + input.remaining()
                    + " bytes of ciphertext plus " + tagLen + " bytes of tag");
        }

        processAAD();

        int totalLen = 0;
        int len;
        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);

            // if GCM-DECRYPT, we have to handle the buffered input
            // and then retrieve the trailing tag from the combined input
            if (inBuffer != null && inBuffer.size() > 0) {
                final byte[] inputBytes = new byte[input.remaining()];
                input.get(inputBytes, 0, inputBytes.length);
                inBuffer.write(inputBytes, 0, inputBytes.length);
                final byte[] inputFinal = inBuffer.toByteArray();
                inBuffer.reset();

                if (inputFinal.length < tagLen) {
                    throw new AEADBadTagException("Input too short - need tag");
                }

                final int inputDataLen = inputFinal.length - tagLen;
                len = OpenSslNative.updateByteArrayByteBuffer(context, inputFinal, 0,
                        inputDataLen, output, output.position(), output.remaining());

                // retrieve tag
                tag.put(inputFinal, inputDataLen, tagLen);
            } else {
                // if no buffered input, just use the input directly
                if (input.remaining() < tagLen) {
                    throw new AEADBadTagException("Input too short - need tag");
                }

                final int inputDataLen = input.remaining() - tagLen;
                len = OpenSslNative.update(context, input, input.position(),
                        inputDataLen, output, output.position(), output.remaining());

                // Advance by the bytes consumed (not the bytes produced) so exactly the tag remains.
                input.position(input.position() + inputDataLen);

                // retrieve tag
                tag.put(input);
            }
            tag.flip();

            // set tag to EVP_Cipher for integrity verification in doFinal
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_SET_TAG.getValue(), tagLen, tag);
        } else {
            len = OpenSslNative.update(context, input, input.position(),
                    input.remaining(), output, output.position(),
                    output.remaining());
            input.position(input.limit());
        }

        totalLen += len;
        output.position(output.position() + len);

        len = OpenSslNative.doFinal(context, output, output.position(),
                output.remaining());
        output.position(output.position() + len);
        totalLen += len;

        // Keep the similar behavior as JCE: append the tag directly after the ciphertext.
        if (this.cipherMode == OpenSsl.ENCRYPT_MODE) {
            final ByteBuffer tag = ByteBuffer.allocate(tagLen);
            evpCipherCtxCtrl(context, OpenSslEvpCtrlValues.AEAD_GET_TAG.getValue(), tagLen, tag);
            output.put(tag);
            totalLen += tagLen;
        }

        return totalLen;
    }

    /**
     * Wraps OpenSslNative.ctrl(long context, int type, int arg, byte[] data).
     * Since the native interface EVP_CIPHER_CTX_ctrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr)
     * is generic, it may set/get any native char or long type to the data buffer (ptr).
     * Here we use a heap ByteBuffer and set nativeOrder to handle the endianness; its backing
     * array is what is passed to the native call.
     *
     * <p>Errors from the native call are propagated instead of being swallowed: ignoring a failed
     * AEAD_GET_TAG would append the still zero-filled tag buffer to the ciphertext, and ignoring a
     * failed AEAD_SET_TAG would let decryption continue without the caller's tag being set.
     *
     * @param context The cipher context address
     * @param type CtrlValues
     * @param arg argument like a tag length
     * @param data heap byte buffer (must expose a backing array) or null
     * @return the value returned by {@code OpenSslNative.ctrl}
     */
    private int evpCipherCtxCtrl(final long context, final int type, final int arg, final ByteBuffer data) {
        checkState();
        if (data == null) {
            return OpenSslNative.ctrl(context, type, arg, null);
        }
        data.order(ByteOrder.nativeOrder());
        return OpenSslNative.ctrl(context, type, arg, data.array());
    }

    /**
     * Returns the tag length in bytes: {@code tagBitLen / 8} once {@link #init} has set it,
     * otherwise {@link #DEFAULT_TAG_LEN}.
     */
    private int getTagLen() {
        return tagBitLen < 0 ? DEFAULT_TAG_LEN : tagBitLen >> 3;
    }

    /**
     * Initializes the cipher for GCM.
     *
     * <p>The parameter spec is validated before any state is modified, so a rejected spec
     * leaves the previous configuration untouched.
     *
     * @param mode OpenSsl.ENCRYPT_MODE or OpenSsl.DECRYPT_MODE
     * @param key the raw key bytes
     * @param params must be a {@link GCMParameterSpec} (supplies IV and tag length)
     * @throws InvalidAlgorithmParameterException if {@code params} is not a GCMParameterSpec
     */
    @Override
    public void init(final int mode, final byte[] key, final AlgorithmParameterSpec params)
            throws InvalidAlgorithmParameterException {
        if (!(params instanceof GCMParameterSpec)) {
            // other AlgorithmParameterSpec is not supported now.
            throw new InvalidAlgorithmParameterException("Illegal parameters");
        }
        final GCMParameterSpec gcmParam = (GCMParameterSpec) params;

        if (aadBuffer == null) {
            aadBuffer = new ByteArrayOutputStream();
        } else {
            aadBuffer.reset();
        }

        this.cipherMode = mode;
        this.tagBitLen = gcmParam.getTLen();

        // Only decryption buffers its input; drop any stale buffer from a previous decrypt session.
        inBuffer = this.cipherMode == OpenSsl.DECRYPT_MODE ? new ByteArrayOutputStream() : null;

        context = OpenSslNative.init(context, mode, algorithmMode, padding, key, gcmParam.getIV());
    }

    /**
     * Feeds any accumulated AAD to the native context (update with a null output) and then
     * discards the buffer, so later {@link #updateAAD(byte[])} calls are rejected.
     */
    private void processAAD() {
        if (aadBuffer != null && aadBuffer.size() > 0) {
            OpenSslNative.updateByteArray(context, aadBuffer.toByteArray(), 0, aadBuffer.size(), null, 0, 0);
            aadBuffer = null;
        }
    }

    /**
     * Continues a GCM operation over a byte-array slice.
     *
     * <p>When decrypting, the input is only buffered and 0 is returned; plaintext is released by
     * {@code doFinal} after the tag has been handed to OpenSSL.
     *
     * @return number of bytes written to {@code output}
     * @throws ShortBufferException propagated from the native layer
     */
    @Override
    public int update(final byte[] input, final int inputOffset, final int inputLen, final byte[] output, final int outputOffset)
            throws ShortBufferException {
        checkState();

        processAAD();

        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            // store internally until doFinal(decrypt) is called because
            // spec mentioned that only return recovered data after tag
            // is successfully verified
            inBuffer.write(input, inputOffset, inputLen);
            return 0;
        }
        return OpenSslNative.updateByteArray(context, input, inputOffset,
                inputLen, output, outputOffset, output.length - outputOffset);
    }

    /**
     * Continues a GCM operation over {@link ByteBuffer}s. {@code input} is fully consumed.
     *
     * <p>When decrypting, the input is only buffered and 0 is returned; plaintext is released by
     * {@code doFinal} after the tag has been handed to OpenSSL.
     *
     * @return number of bytes written to {@code output} (its position is advanced accordingly)
     * @throws ShortBufferException propagated from the native layer
     */
    @Override
    public int update(final ByteBuffer input, final ByteBuffer output) throws ShortBufferException {
        checkState();

        processAAD();

        if (this.cipherMode == OpenSsl.DECRYPT_MODE) {
            // store internally until doFinal(decrypt) is called because
            // spec mentioned that only return recovered data after tag
            // is successfully verified
            final int inputLen = input.remaining();
            final byte[] inputBuf = new byte[inputLen];
            input.get(inputBuf, 0, inputLen);
            inBuffer.write(inputBuf, 0, inputLen);
            return 0;
        }
        final int len = OpenSslNative.update(context, input, input.position(),
                input.remaining(), output, output.position(),
                output.remaining());
        input.position(input.limit());
        output.position(output.position() + len);

        return len;
    }

    /**
     * Accumulates additional authenticated data; it is sent to OpenSSL on the first
     * {@code update}/{@code doFinal} call.
     *
     * @param aad the AAD bytes
     * @throws IllegalStateException if the AAD buffer has already been consumed by
     *         {@code update}/{@code doFinal} or discarded by {@link #clean()}
     */
    @Override
    public void updateAAD(final byte[] aad) {
        // must be called after initialized.
        if (aadBuffer == null) {
            // update has already been called
            throw new IllegalStateException("Update has been called; no more AAD data");
        }
        aadBuffer.write(aad, 0, aad.length);
    }
}