diff --git a/tokio-io/src/io/async_read_ext.rs b/tokio-io/src/io/async_read_ext.rs
index 0976fd1217e..69265df26ec 100644
--- a/tokio-io/src/io/async_read_ext.rs
+++ b/tokio-io/src/io/async_read_ext.rs
@@ -1,12 +1,27 @@
+use crate::io::chain::{chain, Chain};
 use crate::io::copy::{copy, Copy};
 use crate::io::read::{read, Read};
 use crate::io::read_exact::{read_exact, ReadExact};
 use crate::io::read_to_end::{read_to_end, ReadToEnd};
 use crate::io::read_to_string::{read_to_string, ReadToString};
+use crate::io::take::{take, Take};
 use crate::{AsyncRead, AsyncWrite};
 
 /// An extension trait which adds utility methods to `AsyncRead` types.
 pub trait AsyncReadExt: AsyncRead {
+    /// Creates an adaptor which will chain this stream with another.
+    ///
+    /// The returned `AsyncRead` instance will first read all bytes from this object
+    /// until EOF is encountered. Afterwards the output is equivalent to the
+    /// output of `next`.
+    fn chain<R>(self, next: R) -> Chain<Self, R>
+    where
+        Self: Sized,
+        R: AsyncRead,
+    {
+        chain(self, next)
+    }
+
     /// Copy all data from `self` into the provided `AsyncWrite`.
     ///
     /// The returned future will copy all the bytes read from `reader` into the
@@ -63,6 +78,15 @@ pub trait AsyncReadExt: AsyncRead {
     {
         read_to_string(self, dst)
     }
+
+    /// Creates an AsyncRead adapter which will read at most `limit` bytes
+    /// from the underlying reader.
+    fn take(self, limit: u64) -> Take<Self>
+    where
+        Self: Sized,
+    {
+        take(self, limit)
+    }
 }
 
 impl<R: AsyncRead> AsyncReadExt for R {}
diff --git a/tokio-io/src/io/chain.rs b/tokio-io/src/io/chain.rs
new file mode 100644
index 00000000000..d7de5f219a4
--- /dev/null
+++ b/tokio-io/src/io/chain.rs
@@ -0,0 +1,142 @@
+use crate::{AsyncBufRead, AsyncRead};
+use futures_core::ready;
+use pin_utils::{unsafe_pinned, unsafe_unpinned};
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
+#[must_use = "streams do nothing unless polled"]
+pub struct Chain<T, U> {
+    first: T,
+    second: U,
+    done_first: bool,
+}
+
+impl<T, U> Unpin for Chain<T, U>
+where
+    T: Unpin,
+    U: Unpin,
+{
+}
+
+pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
+where
+    T: AsyncRead,
+    U: AsyncRead,
+{
+    Chain {
+        first,
+        second,
+        done_first: false,
+    }
+}
+
+impl<T, U> Chain<T, U>
+where
+    T: AsyncRead,
+    U: AsyncRead,
+{
+    unsafe_pinned!(first: T);
+    unsafe_pinned!(second: U);
+    unsafe_unpinned!(done_first: bool);
+
+    /// Gets references to the underlying readers in this `Chain`.
+    pub fn get_ref(&self) -> (&T, &U) {
+        (&self.first, &self.second)
+    }
+
+    /// Gets mutable references to the underlying readers in this `Chain`.
+    ///
+    /// Care should be taken to avoid modifying the internal I/O state of the
+    /// underlying readers as doing so may corrupt the internal state of this
+    /// `Chain`.
+    pub fn get_mut(&mut self) -> (&mut T, &mut U) {
+        (&mut self.first, &mut self.second)
+    }
+
+    /// Gets pinned mutable references to the underlying readers in this `Chain`.
+    ///
+    /// Care should be taken to avoid modifying the internal I/O state of the
+    /// underlying readers as doing so may corrupt the internal state of this
+    /// `Chain`.
+    pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
+        unsafe {
+            let Self { first, second, .. } = self.get_unchecked_mut();
+            (Pin::new_unchecked(first), Pin::new_unchecked(second))
+        }
+    }
+
+    /// Consumes the `Chain`, returning the wrapped readers.
+    pub fn into_inner(self) -> (T, U) {
+        (self.first, self.second)
+    }
+}
+
+impl<T, U> fmt::Debug for Chain<T, U>
+where
+    T: fmt::Debug,
+    U: fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Chain")
+            .field("t", &self.first)
+            .field("u", &self.second)
+            .finish()
+    }
+}
+
+impl<T, U> AsyncRead for Chain<T, U>
+where
+    T: AsyncRead,
+    U: AsyncRead,
+{
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
+        if !self.done_first {
+            match ready!(self.as_mut().first().poll_read(cx, buf)?) {
+                0 if !buf.is_empty() => *self.as_mut().done_first() = true,
+                n => return Poll::Ready(Ok(n)),
+            }
+        }
+        self.second().poll_read(cx, buf)
+    }
+}
+
+impl<T, U> AsyncBufRead for Chain<T, U>
+where
+    T: AsyncBufRead,
+    U: AsyncBufRead,
+{
+    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+        let Self {
+            first,
+            second,
+            done_first,
+        } = unsafe { self.get_unchecked_mut() };
+        let first = unsafe { Pin::new_unchecked(first) };
+        let second = unsafe { Pin::new_unchecked(second) };
+
+        if !*done_first {
+            match ready!(first.poll_fill_buf(cx)?) {
+                buf if buf.is_empty() => {
+                    *done_first = true;
+                }
+                buf => return Poll::Ready(Ok(buf)),
+            }
+        }
+        second.poll_fill_buf(cx)
+    }
+
+    fn consume(self: Pin<&mut Self>, amt: usize) {
+        if !self.done_first {
+            self.first().consume(amt)
+        } else {
+            self.second().consume(amt)
+        }
+    }
+}
diff --git a/tokio-io/src/io/mod.rs b/tokio-io/src/io/mod.rs
index 888f6548657..36eae7a3d8b 100644
--- a/tokio-io/src/io/mod.rs
+++ b/tokio-io/src/io/mod.rs
@@ -3,6 +3,7 @@ mod async_read_ext;
 mod async_write_ext;
 mod buf_reader;
 mod buf_writer;
+mod chain;
 mod copy;
 mod flush;
 mod lines;
@@ -13,6 +14,7 @@ mod read_to_end;
 mod read_to_string;
 mod read_until;
 mod shutdown;
+mod take;
 mod write;
 mod write_all;
 
diff --git a/tokio-io/src/io/take.rs b/tokio-io/src/io/take.rs
new file mode 100644
index 00000000000..b0b7975a693
--- /dev/null
+++ b/tokio-io/src/io/take.rs
@@ -0,0 +1,120 @@
+use crate::{AsyncBufRead, AsyncRead};
+use futures_core::ready;
+use pin_utils::{unsafe_pinned, unsafe_unpinned};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{cmp, io};
+
+/// Stream for the [`take`](super::AsyncReadExt::take) method.
+#[derive(Debug)]
+#[must_use = "streams do nothing unless you `.await` or poll them"]
+pub struct Take<R> {
+    inner: R,
+    // Add '_' to avoid conflicts with `limit` method.
+    limit_: u64,
+}
+
+impl<R: Unpin> Unpin for Take<R> {}
+
+pub(super) fn take<R: AsyncRead>(inner: R, limit: u64) -> Take<R> {
+    Take {
+        inner,
+        limit_: limit,
+    }
+}
+
+impl<R: AsyncRead> Take<R> {
+    unsafe_pinned!(inner: R);
+    unsafe_unpinned!(limit_: u64);
+
+    /// Returns the remaining number of bytes that can be
+    /// read before this instance will return EOF.
+    ///
+    /// # Note
+    ///
+    /// This instance may reach `EOF` after reading fewer bytes than indicated by
+    /// this method if the underlying [`AsyncRead`] instance reaches EOF.
+    pub fn limit(&self) -> u64 {
+        self.limit_
+    }
+
+    /// Sets the number of bytes that can be read before this instance will
+    /// return EOF. This is the same as constructing a new `Take` instance, so
+    /// the amount of bytes read and the previous limit value don't matter when
+    /// calling this method.
+    pub fn set_limit(&mut self, limit: u64) {
+        self.limit_ = limit
+    }
+
+    /// Gets a reference to the underlying reader.
+    pub fn get_ref(&self) -> &R {
+        &self.inner
+    }
+
+    /// Gets a mutable reference to the underlying reader.
+    ///
+    /// Care should be taken to avoid modifying the internal I/O state of the
+    /// underlying reader as doing so may corrupt the internal limit of this
+    /// `Take`.
+    pub fn get_mut(&mut self) -> &mut R {
+        &mut self.inner
+    }
+
+    /// Gets a pinned mutable reference to the underlying reader.
+    ///
+    /// Care should be taken to avoid modifying the internal I/O state of the
+    /// underlying reader as doing so may corrupt the internal limit of this
+    /// `Take`.
+    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
+        self.inner()
+    }
+
+    /// Consumes the `Take`, returning the wrapped reader.
+    pub fn into_inner(self) -> R {
+        self.inner
+    }
+}
+
+impl<R: AsyncRead> AsyncRead for Take<R> {
+    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
+        self.inner.prepare_uninitialized_buffer(buf)
+    }
+
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<io::Result<usize>> {
+        if self.limit_ == 0 {
+            return Poll::Ready(Ok(0));
+        }
+
+        let max = std::cmp::min(buf.len() as u64, self.limit_) as usize;
+        let n = ready!(self.as_mut().inner().poll_read(cx, &mut buf[..max]))?;
+        *self.as_mut().limit_() -= n as u64;
+        Poll::Ready(Ok(n))
+    }
+}
+
+impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
+    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+        let Self { inner, limit_ } = unsafe { self.get_unchecked_mut() };
+        let inner = unsafe { Pin::new_unchecked(inner) };
+
+        // Don't call into inner reader at all at EOF because it may still block
+        if *limit_ == 0 {
+            return Poll::Ready(Ok(&[]));
+        }
+
+        let buf = ready!(inner.poll_fill_buf(cx)?);
+        let cap = cmp::min(buf.len() as u64, *limit_) as usize;
+        Poll::Ready(Ok(&buf[..cap]))
+    }
+
+    fn consume(mut self: Pin<&mut Self>, amt: usize) {
+        // Don't let callers reset the limit by passing an overlarge value
+        let amt = cmp::min(amt as u64, self.limit_) as usize;
+        *self.as_mut().limit_() -= amt as u64;
+        self.inner().consume(amt);
+    }
+}
diff --git a/tokio-io/tests/chain.rs b/tokio-io/tests/chain.rs
new file mode 100644
index 00000000000..7cd45932eda
--- /dev/null
+++ b/tokio-io/tests/chain.rs
@@ -0,0 +1,16 @@
+#![warn(rust_2018_idioms)]
+#![feature(async_await)]
+
+use tokio_io::AsyncReadExt;
+use tokio_test::assert_ok;
+
+#[tokio::test]
+async fn chain() {
+    let mut buf = Vec::new();
+    let rd1: &[u8] = b"hello ";
+    let rd2: &[u8] = b"world";
+
+    let mut rd = rd1.chain(rd2);
+    assert_ok!(rd.read_to_end(&mut buf).await);
+    assert_eq!(buf, b"hello world");
+}
diff --git a/tokio-io/tests/take.rs b/tokio-io/tests/take.rs
new file mode 100644
index 00000000000..ac1e831ffab
--- /dev/null
+++ b/tokio-io/tests/take.rs
@@ -0,0 +1,16 @@
+#![warn(rust_2018_idioms)]
+#![feature(async_await)]
+
+use tokio_io::AsyncReadExt;
+use tokio_test::assert_ok;
+
+#[tokio::test]
+async fn take() {
+    let mut buf = [0; 6];
+    let rd: &[u8] = b"hello world";
+
+    let mut rd = rd.take(4);
+    let n = assert_ok!(rd.read(&mut buf).await);
+    assert_eq!(n, 4);
+    assert_eq!(&buf, &b"hell\0\0"[..]);
+}
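Reviewer note: a minimal usage sketch, not part of this diff, showing the two new adapters composed. It assumes the same nightly `async_await` setup and test helpers (`tokio_test::assert_ok`, `#[tokio::test]`) as tests/chain.rs and tests/take.rs above; the test name `chain_then_take` is hypothetical.

#![warn(rust_2018_idioms)]
#![feature(async_await)]

use tokio_io::AsyncReadExt;
use tokio_test::assert_ok;

#[tokio::test]
async fn chain_then_take() {
    let rd1: &[u8] = b"hello ";
    let rd2: &[u8] = b"world";

    // Chain the two readers, then cap the combined stream at 8 bytes:
    // all 6 bytes of `rd1` followed by the first 2 bytes of `rd2`.
    let mut rd = rd1.chain(rd2).take(8);

    let mut buf = Vec::new();
    assert_ok!(rd.read_to_end(&mut buf).await);
    assert_eq!(buf, b"hello wo");
}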