1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.commons.crypto.stream;
20
21 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
22 import static org.junit.jupiter.api.Assertions.assertEquals;
23 import static org.junit.jupiter.api.Assertions.assertThrows;
24 import static org.junit.jupiter.api.Assumptions.assumeTrue;
25
26 import java.io.ByteArrayOutputStream;
27 import java.io.IOException;
28 import java.io.OutputStream;
29 import java.nio.ByteBuffer;
30 import java.security.SecureRandom;
31 import java.util.Arrays;
32 import java.util.Objects;
33 import java.util.Properties;
34 import java.util.Random;
35
36 import javax.crypto.spec.IvParameterSpec;
37
38 import org.apache.commons.crypto.Crypto;
39 import org.apache.commons.crypto.cipher.AbstractCipherTest;
40 import org.apache.commons.crypto.cipher.CryptoCipher;
41 import org.apache.commons.crypto.stream.input.Input;
42 import org.apache.commons.crypto.utils.AES;
43 import org.apache.commons.crypto.utils.ReflectionUtils;
44 import org.junit.jupiter.api.BeforeEach;
45 import org.junit.jupiter.api.Test;
46
47 public class PositionedCryptoInputStreamTest {
48
49 static class PositionedInputForTest implements Input {
50
51 final byte[] data;
52 long pos;
53 final long count;
54
55 public PositionedInputForTest(final byte[] data) {
56 this.data = data;
57 this.pos = 0;
58 this.count = data.length;
59 }
60
61 @Override
62 public int available() {
63 return (int) (count - pos);
64 }
65
66 @Override
67 public void close() {
68 }
69
70 @Override
71 public int read(final ByteBuffer dst) {
72 final int remaining = (int) (count - pos);
73 if (remaining <= 0) {
74 return -1;
75 }
76
77 final int length = Math.min(dst.remaining(), remaining);
78 dst.put(data, (int) pos, length);
79 pos += length;
80 return length;
81 }
82
83 @Override
84 public int read(final long position, final byte[] buffer, final int offset, int length) {
85 Objects.requireNonNull(buffer, "buffer");
86 if (offset < 0 || length < 0
87 || length > buffer.length - offset) {
88 throw new IndexOutOfBoundsException();
89 }
90
91 if (position < 0 || position >= count) {
92 return -1;
93 }
94
95 final long avail = count - position;
96 if (length > avail) {
97 length = (int) avail;
98 }
99 if (length <= 0) {
100 return 0;
101 }
102 System.arraycopy(data, (int) position, buffer, offset, length);
103 return length;
104 }
105
106 @Override
107 public void seek(final long position) throws IOException {
108 if (pos < 0) {
109 throw new IOException("Negative seek offset");
110 }
111 if (position >= 0 && position < count) {
112 pos = position;
113 } else {
114
115 pos = count;
116 }
117 }
118
119 @Override
120 public long skip(long n) {
121 if (n <= 0) {
122 return 0;
123 }
124
125 final long remaining = count - pos;
126 if (remaining < n) {
127 n = remaining;
128 }
129 pos += n;
130
131 return n;
132 }
133 }
134 private final int dataLen = 20000;
135 private final byte[] testData = new byte[dataLen];
136 private byte[] encData;
137 private final Properties props = new Properties();
138 private final byte[] key = new byte[16];
139 private final byte[] iv = new byte[16];
140 private final int bufferSize = 2048;
141 private final int bufferSizeLess = bufferSize - 1;
142 private final int bufferSizeMore = bufferSize + 1;
143 private final int length = 1024;
144 private final int lengthLess = length - 1;
145
146 private final int lengthMore = length + 1;
147
148 private final String transformation = AES.CTR_NO_PADDING;
149
150 @BeforeEach
151 public void before() throws IOException {
152 final Random random = new SecureRandom();
153 random.nextBytes(testData);
154 random.nextBytes(key);
155 random.nextBytes(iv);
156 prepareData();
157 }
158
159
160 private void compareByteArray(final byte[] data1, final int pos, final byte[] data2,
161 final int length) {
162 final byte[] expectedData = new byte[length];
163 final byte[] realData = new byte[length];
164
165 System.arraycopy(data1, pos, expectedData, 0, length);
166
167 System.arraycopy(data2, 0, realData, 0, length);
168 assertArrayEquals(expectedData, realData);
169 }
170
171 private void doMultipleReadTest() throws Exception{
172 final PositionedCryptoInputStream in = getCryptoInputStream(0);
173 final String cipherClass = in.getCipher().getClass().getName();
174 doMultipleReadTest(cipherClass);
175 }
176
177
178
179 private void doMultipleReadTest(final String cipherClass) throws Exception {
180 try (PositionedCryptoInputStream in = getCryptoInputStream(getCipher(cipherClass), bufferSize)) {
181 int position = 0;
182 while (in.available() > 0) {
183 final ByteBuffer buf = ByteBuffer.allocate(length);
184 final byte[] bytes1 = new byte[length];
185 final byte[] bytes2 = new byte[lengthLess];
186
187 final int pn1 = in.read(position, bytes1, 0, length);
188 final int n = in.read(buf);
189 final int pn2 = in.read(position, bytes2, 0, lengthLess);
190
191
192 if (pn1 > 0) {
193 compareByteArray(testData, position, bytes1, pn1);
194 }
195
196 if (pn2 > 0) {
197 compareByteArray(testData, position, bytes2, pn2);
198 }
199
200 if (n <= 0) {
201 break;
202 }
203 compareByteArray(testData, position, buf.array(), n);
204 position += n;
205 }
206 }
207 }
208
209 private void doPositionedReadTests() throws Exception {
210 final PositionedCryptoInputStream in = getCryptoInputStream(0);
211 final String cipherClass = in.getCipher().getClass().getName();
212 doPositionedReadTests(cipherClass);
213 }
214
215 private void doPositionedReadTests(final String cipherClass) throws Exception {
216
217 testPositionedReadLoop(cipherClass, 0, length, bufferSize, dataLen);
218 testPositionedReadLoop(cipherClass, 0, length, bufferSizeLess, dataLen);
219 testPositionedReadLoop(cipherClass, 0, length, bufferSizeMore, dataLen);
220
221 testPositionedReadLoop(cipherClass, dataLen / 2, length, bufferSize,
222 dataLen);
223 testPositionedReadLoop(cipherClass, dataLen / 2 - 1, length,
224 bufferSizeLess, dataLen);
225 testPositionedReadLoop(cipherClass, dataLen / 2 + 1, length,
226 bufferSizeMore, dataLen);
227
228 testPositionedReadNone(cipherClass, -1, length, bufferSize);
229 testPositionedReadNone(cipherClass, dataLen, length, bufferSize);
230 }
231
232 private void doReadFullyTests() throws Exception {
233 final PositionedCryptoInputStream in = getCryptoInputStream(0);
234 final String cipherClass = in.getCipher().getClass().getName();
235 doReadFullyTests(cipherClass);
236 }
237
238 private void doReadFullyTests(final String cipherClass) throws Exception {
239
240 testReadFullyLoop(cipherClass, 0, length, bufferSize, dataLen);
241 testReadFullyLoop(cipherClass, 0, length, bufferSizeLess, dataLen);
242 testReadFullyLoop(cipherClass, 0, length, bufferSizeMore, dataLen);
243
244 testReadFullyLoop(cipherClass, 0, length, bufferSize, dataLen);
245 testReadFullyLoop(cipherClass, 0, lengthLess, bufferSize, dataLen);
246 testReadFullyLoop(cipherClass, 0, lengthMore, bufferSize, dataLen);
247
248 testReadFullyFailed(cipherClass, -1, length, bufferSize);
249 testReadFullyFailed(cipherClass, dataLen, length, bufferSize);
250 testReadFullyFailed(cipherClass, dataLen - length + 1, length,
251 bufferSize);
252 }
253
254 private void doSeekTests() throws Exception{
255 final PositionedCryptoInputStream in = getCryptoInputStream(0);
256 final String cipherClass = in.getCipher().getClass().getName();
257 doSeekTests(cipherClass);
258 }
259
260 private void doSeekTests(final String cipherClass) throws Exception {
261
262 testSeekLoop(cipherClass, 0, length, bufferSize);
263 testSeekLoop(cipherClass, 0, lengthLess, bufferSize);
264 testSeekLoop(cipherClass, 0, lengthMore, bufferSize);
265
266 testSeekLoop(cipherClass, dataLen, length, bufferSize);
267
268 testSeekFailed(cipherClass, -1, bufferSize);
269 }
270
271 @Test
272 public void doTestJCE() throws Exception {
273 testCipher(AbstractCipherTest.JCE_CIPHER_CLASSNAME);
274 }
275
276 @Test
277 public void doTestJNI() throws Exception {
278 assumeTrue(Crypto.isNativeCodeLoaded());
279 testCipher(AbstractCipherTest.OPENSSL_CIPHER_CLASSNAME);
280 }
281
282 private CryptoCipher getCipher(final String cipherClass) throws IOException {
283 try {
284 return (CryptoCipher) ReflectionUtils.newInstance(
285 ReflectionUtils.getClassByName(cipherClass), props,
286 transformation);
287 } catch (final ClassNotFoundException cnfe) {
288 throw new IOException("Illegal crypto cipher!");
289 }
290 }
291
292 private PositionedCryptoInputStream getCryptoInputStream(
293 final CryptoCipher cipher, final int bufferSize) throws IOException {
294 return new PositionedCryptoInputStream(props, new PositionedInputForTest(
295 Arrays.copyOf(encData, encData.length)), cipher, bufferSize,
296 key, iv, 0);
297 }
298
299 private PositionedCryptoInputStream getCryptoInputStream(final int streamOffset)
300 throws IOException {
301 return new PositionedCryptoInputStream(props, new PositionedInputForTest(
302 Arrays.copyOf(encData, encData.length)), key, iv, streamOffset);
303 }
304
305 private void prepareData() throws IOException {
306 final CryptoCipher cipher;
307 try {
308 cipher = (CryptoCipher) ReflectionUtils.newInstance(
309 ReflectionUtils.getClassByName(AbstractCipherTest.JCE_CIPHER_CLASSNAME), props,
310 transformation);
311 } catch (final ClassNotFoundException cnfe) {
312 throw new IOException("Illegal crypto cipher!");
313 }
314
315 final ByteArrayOutputStream baos = new ByteArrayOutputStream();
316
317 try (final OutputStream out = new CryptoOutputStream(baos, cipher, bufferSize,
318 AES.newSecretKeySpec(key), new IvParameterSpec(iv))) {
319 out.write(testData);
320 out.flush();
321 }
322 encData = baos.toByteArray();
323 }
324
325 protected void testCipher(final String cipherClass) throws Exception {
326 doPositionedReadTests(cipherClass);
327 doPositionedReadTests();
328 doReadFullyTests(cipherClass);
329 doReadFullyTests();
330 doSeekTests(cipherClass);
331 doSeekTests();
332 doMultipleReadTest(cipherClass);
333 doMultipleReadTest();
334 }
335
336 private void testPositionedReadLoop(final String cipherClass, int position,
337 final int length, final int bufferSize, final int total) throws Exception {
338 try (PositionedCryptoInputStream in = getCryptoInputStream(getCipher(cipherClass), bufferSize)) {
339
340 while (position < total) {
341 final byte[] bytes = new byte[length];
342 final int n = in.read(position, bytes, 0, length);
343 if (n < 0) {
344 break;
345 }
346 compareByteArray(testData, position, bytes, n);
347 position += n;
348 }
349 }
350 }
351
352
353 private void testPositionedReadNone(final String cipherClass, final int position,
354 final int length, final int bufferSize) throws Exception {
355 try (PositionedCryptoInputStream in = getCryptoInputStream(getCipher(cipherClass), bufferSize)) {
356 final byte[] bytes = new byte[length];
357 final int n = in.read(position, bytes, 0, length);
358 assertEquals(n, -1);
359 }
360 }
361
362
363 private void testReadFullyFailed(final String cipherClass, final int position,
364 final int length, final int bufferSize) throws Exception {
365 try (final PositionedCryptoInputStream in = getCryptoInputStream(getCipher(cipherClass), bufferSize)) {
366 final byte[] bytes = new byte[length];
367 assertThrows(IOException.class, () -> in.readFully(position, bytes, 0, length));
368 in.close();
369 in.close();
370 }
371 }
372
373 private void testReadFullyLoop(final String cipherClass, int position,
374 final int length, final int bufferSize, final int total) throws Exception {
375 try (PositionedCryptoInputStream in = getCryptoInputStream(
376 getCipher(cipherClass), bufferSize)) {
377
378
379 while (position + length <= total) {
380 final byte[] bytes = new byte[length];
381 in.readFully(position, bytes);
382 compareByteArray(testData, position, bytes, length);
383 position += length;
384 }
385
386 }
387 }
388
389
390 private void testSeekFailed(final String cipherClass, final int position, final int bufferSize)
391 throws Exception {
392 try (final PositionedCryptoInputStream in = getCryptoInputStream(getCipher(cipherClass), bufferSize)) {
393 assertThrows(IllegalArgumentException.class, () -> in.seek(position));
394 }
395 }
396
397 private void testSeekLoop(final String cipherClass, int position, final int length,
398 final int bufferSize) throws Exception {
399 try (PositionedCryptoInputStream in = getCryptoInputStream(getCipher(cipherClass), bufferSize)) {
400 while (in.available() > 0) {
401 in.seek(position);
402 final ByteBuffer buf = ByteBuffer.allocate(length);
403 final int n = in.read(buf);
404 if (n <= 0) {
405 break;
406 }
407 compareByteArray(testData, position, buf.array(), n);
408 position += n;
409 }
410 }
411 }
412 }