diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferInputStream.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferInputStream.java index e8a1b179a7a6..33135285ae95 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferInputStream.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferInputStream.java @@ -77,8 +77,9 @@ public boolean markSupported() { } @Override - public void mark(int mark) { - this.mark = mark; + public void mark(int readLimit) { + Assert.isTrue(readLimit > 0, "readLimit must be greater than 0"); + this.mark = this.dataBuffer.readPosition(); } @Override diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java index 786290a3f9f3..6cd9c4bc9b29 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java @@ -318,6 +318,9 @@ void inputStream(DataBufferFactory bufferFactory) throws Exception { assertThat(result).isEqualTo((byte) 'b'); assertThat(inputStream.available()).isEqualTo(3); + assertThat(inputStream.markSupported()).isTrue(); + inputStream.mark(2); + byte[] bytes = new byte[2]; int len = inputStream.read(bytes); assertThat(len).isEqualTo(2); @@ -333,6 +336,12 @@ void inputStream(DataBufferFactory bufferFactory) throws Exception { assertThat(inputStream.read()).isEqualTo(-1); assertThat(inputStream.read(bytes)).isEqualTo(-1); + inputStream.reset(); + bytes = new byte[3]; + len = inputStream.read(bytes); + assertThat(len).isEqualTo(3); + assertThat(bytes).containsExactly('c', 'd', 'e'); + release(buffer); }