Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AsyncReadExt::chain #1810

Merged
merged 1 commit into from Aug 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
150 changes: 150 additions & 0 deletions futures-util/src/io/chain.rs
@@ -0,0 +1,150 @@
use futures_core::task::{Context, Poll};
use futures_io::{AsyncBufRead, AsyncRead, Initializer, IoSliceMut};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::fmt;
use std::io;
use std::pin::Pin;

/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
#[must_use = "streams do nothing unless polled"]
pub struct Chain<T, U> {
first: T,
second: U,
done_first: bool,
}

impl<T, U> Unpin for Chain<T, U>
where
T: Unpin,
U: Unpin,
{
}

impl<T, U> Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
unsafe_pinned!(first: T);
unsafe_pinned!(second: U);
unsafe_unpinned!(done_first: bool);

pub(super) fn new(first: T, second: U) -> Self {
Self {
first,
second,
done_first: false,
}
}

/// Consumes the `Chain`, returning the wrapped readers.
pub fn into_inner(self) -> (T, U) {
(self.first, self.second)
}

/// Gets references to the underlying readers in this `Chain`.
pub fn get_ref(&self) -> (&T, &U) {
(&self.first, &self.second)
}

/// Gets mutable references to the underlying readers in this `Chain`.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying readers as doing so may corrupt the internal state of this
/// `Chain`.
pub fn get_mut(&mut self) -> (&mut T, &mut U) {
(&mut self.first, &mut self.second)
}
}

impl<T, U> fmt::Debug for Chain<T, U>
where
T: fmt::Debug,
U: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Chain")
.field("t", &self.first)
.field("u", &self.second)
.finish()
}
}

impl<T, U> AsyncRead for Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if !self.done_first {
match ready!(self.as_mut().first().poll_read(cx, buf)?) {
0 if !buf.is_empty() => *self.as_mut().done_first() = true,
n => return Poll::Ready(Ok(n)),
}
}
self.second().poll_read(cx, buf)
}

fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
if !self.done_first {
let n = ready!(self.as_mut().first().poll_read_vectored(cx, bufs)?);
if n == 0 && bufs.iter().any(|b| !b.is_empty()) {
*self.as_mut().done_first() = true
} else {
return Poll::Ready(Ok(n));
}
}
self.second().poll_read_vectored(cx, bufs)
}

unsafe fn initializer(&self) -> Initializer {
let initializer = self.first.initializer();
if initializer.should_initialize() {
initializer
} else {
self.second.initializer()
}
}
}

impl<T, U> AsyncBufRead for Chain<T, U>
where
T: AsyncBufRead,
U: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let Self {
first,
second,
done_first,
} = unsafe { self.get_unchecked_mut() };
let first = unsafe { Pin::new_unchecked(first) };
let second = unsafe { Pin::new_unchecked(second) };

if !*done_first {
match ready!(first.poll_fill_buf(cx)?) {
buf if buf.is_empty() => {
*done_first = true;
}
buf => return Poll::Ready(Ok(buf)),
}
}
second.poll_fill_buf(cx)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
if !self.done_first {
self.first().consume(amt)
} else {
self.second().consume(amt)
}
}
}
42 changes: 39 additions & 3 deletions futures-util/src/io/mod.rs
Expand Up @@ -28,6 +28,12 @@ pub use self::buf_reader::BufReader;
mod buf_writer;
pub use self::buf_writer::BufWriter;

mod chain;
pub use self::chain::Chain;

mod close;
pub use self::close::Close;

mod copy_into;
pub use self::copy_into::CopyInto;

Expand Down Expand Up @@ -66,9 +72,6 @@ pub use self::read_to_string::ReadToString;
mod read_until;
pub use self::read_until::ReadUntil;

mod close;
pub use self::close::Close;

mod seek;
pub use self::seek::Seek;

Expand All @@ -92,6 +95,39 @@ pub use self::write_all::WriteAll;

/// An extension trait which adds utility methods to `AsyncRead` types.
pub trait AsyncReadExt: AsyncRead {
/// Creates an adaptor which will chain this stream with another.
///
/// The returned `AsyncRead` instance will first read all bytes from this object
/// until EOF is encountered. Afterwards the output is equivalent to the
/// output of `next`.
///
/// # Examples
///
/// ```
/// #![feature(async_await)]
/// # futures::executor::block_on(async {
/// use futures::io::AsyncReadExt;
/// use std::io::Cursor;
///
/// let reader1 = Cursor::new([1, 2, 3, 4]);
/// let reader2 = Cursor::new([5, 6, 7, 8]);
///
/// let mut reader = reader1.chain(reader2);
/// let mut buffer = Vec::new();
///
/// // read the value into a Vec.
/// reader.read_to_end(&mut buffer).await?;
/// assert_eq!(buffer, [1, 2, 3, 4, 5, 6, 7, 8]);
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
fn chain<R>(self, next: R) -> Chain<Self, R>
where
Self: Sized,
R: AsyncRead,
{
Chain::new(self, next)
}

/// Creates a future which copies all the bytes from one object to another.
///
/// The returned future will copy all the bytes read from this `AsyncRead` into the
Expand Down
8 changes: 4 additions & 4 deletions futures/src/lib.rs
Expand Up @@ -302,10 +302,10 @@ pub mod io {

pub use futures_util::io::{
AsyncReadExt, AsyncWriteExt, AsyncSeekExt, AsyncBufReadExt, AllowStdIo,
BufReader, BufWriter, Close, CopyInto, CopyBufInto, Flush, IntoSink,
Lines, Read, ReadExact, ReadHalf, ReadLine, ReadToEnd, ReadToString,
ReadUntil, ReadVectored, Seek, Window, Write, WriteAll, WriteHalf,
WriteVectored,
BufReader, BufWriter, Chain, Close, CopyInto, CopyBufInto, Flush,
IntoSink, Lines, Read, ReadExact, ReadHalf, ReadLine, ReadToEnd,
ReadToString, ReadUntil, ReadVectored, Seek, Take, Window, Write,
WriteAll, WriteHalf, WriteVectored,
};
}

Expand Down