diff --git a/futures-util/src/io/buf_reader.rs b/futures-util/src/io/buf_reader.rs index 5931edc1b2..6608d43ebe 100644 --- a/futures-util/src/io/buf_reader.rs +++ b/futures-util/src/io/buf_reader.rs @@ -1,4 +1,5 @@ use super::DEFAULT_BUF_SIZE; +use futures_core::future::Future; use futures_core::ready; use futures_core::task::{Context, Poll}; #[cfg(feature = "read-initializer")] @@ -73,6 +74,43 @@ impl BufReader { } } +impl BufReader { + /// Seeks relative to the current position. If the new position lies within the buffer, + /// the buffer will not be flushed, allowing for more efficient seeks. + /// This method does not return the location of the underlying reader, so the caller + /// must track this information themselves if it is required. + pub fn seek_relative(&mut self, offset: i64) -> SeeKRelative<'_, R> + where + R: Unpin, + { + SeeKRelative { inner: self, offset, first: true } + } + + /// Attempts to seek relative to the current position. If the new position lies within the buffer, + /// the buffer will not be flushed, allowing for more efficient seeks. + /// This method does not return the location of the underlying reader, so the caller + /// must track this information themselves if it is required. + pub fn poll_seek_relative( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + offset: i64, + ) -> Poll> { + let pos = self.pos as u64; + if offset < 0 { + if let Some(new_pos) = pos.checked_sub((-offset) as u64) { + *self.project().pos = new_pos as usize; + return Poll::Ready(Ok(())); + } + } else if let Some(new_pos) = pos.checked_add(offset as u64) { + if new_pos <= self.cap as u64 { + *self.project().pos = new_pos as usize; + return Poll::Ready(Ok(())); + } + } + self.poll_seek(cx, SeekFrom::Current(offset)).map(|res| res.map(|_| ())) + } +} + impl AsyncRead for BufReader { fn poll_read( mut self: Pin<&mut Self>, @@ -163,6 +201,10 @@ impl AsyncSeek for BufReader { /// `.into_inner()` immediately after a seek yields the underlying reader /// at the same position. /// + /// To seek without discarding the internal buffer, use + /// [`BufReader::seek_relative`](BufReader::seek_relative) or + /// [`BufReader::poll_seek_relative`](BufReader::poll_seek_relative). + /// /// See [`AsyncSeek`](futures_io::AsyncSeek) for more details. /// /// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` @@ -200,3 +242,32 @@ impl AsyncSeek for BufReader { Poll::Ready(Ok(result)) } } + +/// Future for the [`BufReader::seek_relative`](self::BufReader::seek_relative) method. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct SeeKRelative<'a, R> { + inner: &'a mut BufReader, + offset: i64, + first: bool, +} + +impl Future for SeeKRelative<'_, R> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let offset = self.offset; + if self.first { + self.first = false; + Pin::new(&mut *self.inner).as_mut().poll_seek_relative(cx, offset) + } else { + Pin::new(&mut *self.inner) + .as_mut() + .poll_seek(cx, SeekFrom::Current(offset)) + .map(|res| res.map(|_| ())) + } + } +} diff --git a/futures-util/src/io/mod.rs b/futures-util/src/io/mod.rs index b96223d1c1..16cf5a7bab 100644 --- a/futures-util/src/io/mod.rs +++ b/futures-util/src/io/mod.rs @@ -56,7 +56,7 @@ mod allow_std; pub use self::allow_std::AllowStdIo; mod buf_reader; -pub use self::buf_reader::BufReader; +pub use self::buf_reader::{BufReader, SeeKRelative}; mod buf_writer; pub use self::buf_writer::BufWriter; diff --git a/futures/tests/io_buf_reader.rs b/futures/tests/io_buf_reader.rs index d60df879c2..060453615e 100644 --- a/futures/tests/io_buf_reader.rs +++ b/futures/tests/io_buf_reader.rs @@ -130,6 +130,57 @@ fn test_buffered_reader_seek() { assert_eq!(block_on(reader.seek(SeekFrom::Current(-2))).ok(), Some(3)); } +#[test] +fn test_buffered_reader_seek_relative() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, Cursor::new(inner)); + + assert!(block_on(reader.seek_relative(3)).is_ok()); + assert_eq!(run_fill_buf!(reader).ok(), Some(&[0, 1][..])); + assert!(block_on(reader.seek_relative(0)).is_ok()); + assert_eq!(run_fill_buf!(reader).ok(), Some(&[0, 1][..])); + assert!(block_on(reader.seek_relative(1)).is_ok()); + assert_eq!(run_fill_buf!(reader).ok(), Some(&[1][..])); + assert!(block_on(reader.seek_relative(-1)).is_ok()); + assert_eq!(run_fill_buf!(reader).ok(), Some(&[0, 1][..])); + assert!(block_on(reader.seek_relative(2)).is_ok()); + assert_eq!(run_fill_buf!(reader).ok(), Some(&[2, 3][..])); +} + +#[test] +fn test_buffered_reader_invalidated_after_read() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(3, Cursor::new(inner)); + + assert_eq!(run_fill_buf!(reader).ok(), Some(&[5, 6, 7][..])); + Pin::new(&mut reader).consume(3); + + let mut buffer = [0, 0, 0, 0, 0]; + assert_eq!(block_on(reader.read(&mut buffer)).ok(), Some(5)); + assert_eq!(buffer, [0, 1, 2, 3, 4]); + + assert!(block_on(reader.seek_relative(-2)).is_ok()); + let mut buffer = [0, 0]; + assert_eq!(block_on(reader.read(&mut buffer)).ok(), Some(2)); + assert_eq!(buffer, [3, 4]); +} + +#[test] +fn test_buffered_reader_invalidated_after_seek() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(3, Cursor::new(inner)); + + assert_eq!(run_fill_buf!(reader).ok(), Some(&[5, 6, 7][..])); + Pin::new(&mut reader).consume(3); + + assert!(block_on(reader.seek(SeekFrom::Current(5))).is_ok()); + + assert!(block_on(reader.seek_relative(-2)).is_ok()); + let mut buffer = [0, 0]; + assert_eq!(block_on(reader.read(&mut buffer)).ok(), Some(2)); + assert_eq!(buffer, [3, 4]); +} + #[test] fn test_buffered_reader_seek_underflow() { // gimmick reader that yields its position modulo 256 for each byte