View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   https://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  
20  package org.apache.commons.compress.archivers.zip;
21  
22  import static org.junit.jupiter.api.Assertions.assertEquals;
23  import static org.junit.jupiter.api.Assertions.assertThrows;
24  import static org.mockito.ArgumentMatchers.any;
25  import static org.mockito.ArgumentMatchers.eq;
26  import static org.mockito.Mockito.mock;
27  import static org.mockito.Mockito.times;
28  import static org.mockito.Mockito.verify;
29  import static org.mockito.Mockito.when;
30  
31  import java.io.IOException;
32  import java.nio.ByteBuffer;
33  import java.nio.channels.FileChannel;
34  import java.nio.channels.SeekableByteChannel;
35  import java.nio.charset.StandardCharsets;
36  import java.nio.file.Files;
37  import java.nio.file.Path;
38  import java.nio.file.StandardOpenOption;
39  
40  import org.apache.commons.compress.AbstractTempDirTest;
41  import org.junit.jupiter.api.Test;
42  
43  /**
44   * Tests {@link SeekableChannelRandomAccessOutputStream}.
45   */
46  class SeekableChannelRandomAccessOutputStreamTest extends AbstractTempDirTest {
47  
48      @Test
49      void testInitialization() throws IOException {
50          final Path file = newTempPath("testChannel");
51          try (SeekableChannelRandomAccessOutputStream stream = new SeekableChannelRandomAccessOutputStream(
52                  Files.newByteChannel(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE))) {
53              assertEquals(0, stream.position());
54          }
55      }
56  
57      @Test
58      void testWrite() throws IOException {
59          final FileChannel channel = mock(FileChannel.class);
60          final SeekableChannelRandomAccessOutputStream stream = new SeekableChannelRandomAccessOutputStream(channel);
61          when(channel.position()).thenReturn(11L);
62          when(channel.write((ByteBuffer) any())).thenAnswer(answer -> {
63              ((ByteBuffer) answer.getArgument(0)).position(5);
64              return 5;
65          }).thenAnswer(answer -> {
66              ((ByteBuffer) answer.getArgument(0)).position(6);
67              return 6;
68          });
69          stream.write("hello".getBytes(StandardCharsets.UTF_8));
70          stream.write("world\n".getBytes(StandardCharsets.UTF_8));
71          verify(channel, times(2)).write((ByteBuffer) any());
72          assertEquals(11, stream.position());
73      }
74  
75      @Test
76      void testWriteFullyAt_whenFullAtOnce_thenSucceed() throws IOException {
77          try (SeekableByteChannel channel = mock(SeekableByteChannel.class);
78                  SeekableChannelRandomAccessOutputStream stream = new SeekableChannelRandomAccessOutputStream(channel)) {
79              when(channel.position()).thenReturn(50L).thenReturn(60L);
80              when(channel.write((ByteBuffer) any())).thenAnswer(answer -> {
81                  ((ByteBuffer) answer.getArgument(0)).position(5);
82                  return 5;
83              }).thenAnswer(answer -> {
84                  ((ByteBuffer) answer.getArgument(0)).position(6);
85                  return 6;
86              });
87              stream.writeAll("hello".getBytes(StandardCharsets.UTF_8), 20);
88              stream.writeAll("world\n".getBytes(StandardCharsets.UTF_8), 30);
89              verify(channel, times(2)).write((ByteBuffer) any());
90              verify(channel, times(1)).position(eq(50L));
91              verify(channel, times(1)).position(eq(60L));
92              assertEquals(60L, stream.position());
93          }
94      }
95  
96      @Test
97      void testWriteFullyAt_whenFullButPartial_thenSucceed() throws IOException {
98          try (SeekableByteChannel channel = mock(SeekableByteChannel.class);
99                  SeekableChannelRandomAccessOutputStream stream = new SeekableChannelRandomAccessOutputStream(channel)) {
100             when(channel.position()).thenReturn(50L).thenReturn(60L);
101             when(channel.write((ByteBuffer) any())).thenAnswer(answer -> {
102                 ((ByteBuffer) answer.getArgument(0)).position(3);
103                 return 3;
104             }).thenAnswer(answer -> {
105                 ((ByteBuffer) answer.getArgument(0)).position(5);
106                 return 2;
107             }).thenAnswer(answer -> {
108                 ((ByteBuffer) answer.getArgument(0)).position(6);
109                 return 6;
110             });
111             stream.writeAll("hello".getBytes(StandardCharsets.UTF_8), 20);
112             stream.writeAll("world\n".getBytes(StandardCharsets.UTF_8), 30);
113             verify(channel, times(3)).write((ByteBuffer) any());
114             verify(channel, times(1)).position(eq(50L));
115             verify(channel, times(1)).position(eq(60L));
116             assertEquals(60L, stream.position());
117         }
118     }
119 
120     @Test
121     void testWriteFullyAt_whenPartial_thenFail() throws IOException {
122         try (SeekableByteChannel channel = mock(SeekableByteChannel.class);
123                 SeekableChannelRandomAccessOutputStream stream = new SeekableChannelRandomAccessOutputStream(channel)) {
124             when(channel.position()).thenReturn(50L);
125             when(channel.write((ByteBuffer) any())).thenAnswer(answer -> {
126                 ((ByteBuffer) answer.getArgument(0)).position(3);
127                 return 3;
128             }).thenAnswer(answer -> 0).thenAnswer(answer -> -1);
129             assertThrows(IOException.class, () -> stream.writeAll("hello".getBytes(StandardCharsets.UTF_8), 20));
130             verify(channel, times(3)).write((ByteBuffer) any());
131             assertEquals(50L, stream.position());
132         }
133     }
134 }