Skip to content

Commit

Permalink
Add some trait/method implementation to AsyncReadExt::{chain, take}
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e committed Aug 22, 2019
1 parent cde791c commit ddf4f55
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 33 deletions.
22 changes: 17 additions & 5 deletions futures-util/src/io/chain.rs
Expand Up @@ -37,11 +37,6 @@ where
}
}

/// Consumes the `Chain`, returning the wrapped readers.
pub fn into_inner(self) -> (T, U) {
(self.first, self.second)
}

/// Gets references to the underlying readers in this `Chain`.
pub fn get_ref(&self) -> (&T, &U) {
(&self.first, &self.second)
Expand All @@ -55,6 +50,23 @@ where
pub fn get_mut(&mut self) -> (&mut T, &mut U) {
(&mut self.first, &mut self.second)
}

/// Gets pinned mutable references to the underlying readers in this `Chain`.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying readers as doing so may corrupt the internal state of this
/// `Chain`.
pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
unsafe {
let Self { first, second, .. } = self.get_unchecked_mut();
(Pin::new_unchecked(first), Pin::new_unchecked(second))
}
}

/// Consumes the `Chain`, returning the wrapped readers.
pub fn into_inner(self) -> (T, U) {
(self.first, self.second)
}
}

impl<T, U> fmt::Debug for Chain<T, U>
Expand Down
2 changes: 1 addition & 1 deletion futures-util/src/io/mod.rs
Expand Up @@ -367,7 +367,7 @@ pub trait AsyncReadExt: AsyncRead {
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
fn take(self, limit: u64) -> Take<Self>
where Self: Sized + Unpin
where Self: Sized
{
Take::new(self, limit)
}
Expand Down
95 changes: 68 additions & 27 deletions futures-util/src/io/take.rs
@@ -1,21 +1,26 @@
use futures_core::task::{Context, Poll};
use futures_io::AsyncRead;
use std::io;
use futures_io::{AsyncRead, AsyncBufRead, Initializer};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{cmp, io};
use std::pin::Pin;

/// Future for the [`take`](super::AsyncReadExt::take) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Take<R: Unpin> {
pub struct Take<R> {
inner: R,
limit: u64,
// Add '_' to avoid conflicts with `limit` method.
limit_: u64,
}

impl<R: Unpin> Unpin for Take<R> { }

impl<R: AsyncRead + Unpin> Take<R> {
impl<R: AsyncRead> Take<R> {
unsafe_pinned!(inner: R);
unsafe_unpinned!(limit_: u64);

pub(super) fn new(inner: R, limit: u64) -> Self {
Take { inner, limit }
Self { inner, limit_: limit }
}

/// Returns the remaining number of bytes that can be
Expand Down Expand Up @@ -43,7 +48,7 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn limit(&self) -> u64 {
self.limit
self.limit_
}

/// Sets the number of bytes that can be read before this instance will
Expand Down Expand Up @@ -74,10 +79,10 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn set_limit(&mut self, limit: u64) {
self.limit = limit
self.limit_ = limit
}

/// Consumes the `Take`, returning the wrapped reader.
/// Gets a reference to the underlying reader.
///
/// # Examples
///
Expand All @@ -92,16 +97,20 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// let mut take = reader.take(4);
/// let n = take.read(&mut buffer).await?;
///
/// let cursor = take.into_inner();
/// assert_eq!(cursor.position(), 4);
/// let cursor_ref = take.get_ref();
/// assert_eq!(cursor_ref.position(), 4);
///
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn into_inner(self) -> R {
self.inner
pub fn get_ref(&self) -> &R {
&self.inner
}

/// Gets a reference to the underlying reader.
/// Gets a mutable reference to the underlying reader.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying reader as doing so may corrupt the internal limit of this
/// `Take`.
///
/// # Examples
///
Expand All @@ -116,20 +125,24 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// let mut take = reader.take(4);
/// let n = take.read(&mut buffer).await?;
///
/// let cursor_ref = take.get_ref();
/// assert_eq!(cursor_ref.position(), 4);
/// let cursor_mut = take.get_mut();
///
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn get_ref(&self) -> &R {
&self.inner
pub fn get_mut(&mut self) -> &mut R {
&mut self.inner
}

/// Gets a mutable reference to the underlying reader.
/// Gets a pinned mutable reference to the underlying reader.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying reader as doing so may corrupt the internal limit of this
/// `Take`.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
self.inner()
}

/// Consumes the `Take`, returning the wrapped reader.
///
/// # Examples
///
Expand All @@ -144,28 +157,56 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// let mut take = reader.take(4);
/// let n = take.read(&mut buffer).await?;
///
/// let cursor_mut = take.get_mut();
/// let cursor = take.into_inner();
/// assert_eq!(cursor.position(), 4);
///
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn get_mut(&mut self) -> &mut R {
&mut self.inner
pub fn into_inner(self) -> R {
self.inner
}
}

impl<R: AsyncRead + Unpin> AsyncRead for Take<R> {
impl<R: AsyncRead> AsyncRead for Take<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
if self.limit == 0 {
if self.limit_ == 0 {
return Poll::Ready(Ok(0));
}

let max = std::cmp::min(buf.len() as u64, self.limit) as usize;
let n = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf[..max]))?;
self.limit -= n as u64;
let max = std::cmp::min(buf.len() as u64, self.limit_) as usize;
let n = ready!(self.as_mut().inner().poll_read(cx, &mut buf[..max]))?;
*self.as_mut().limit_() -= n as u64;
Poll::Ready(Ok(n))
}

unsafe fn initializer(&self) -> Initializer {
self.inner.initializer()
}
}

impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let Self { inner, limit_ } = unsafe { self.get_unchecked_mut() };
let inner = unsafe { Pin::new_unchecked(inner) };

// Don't call into inner reader at all at EOF because it may still block
if *limit_ == 0 {
return Poll::Ready(Ok(&[]));
}

let buf = ready!(inner.poll_fill_buf(cx)?);
let cap = cmp::min(buf.len() as u64, *limit_) as usize;
Poll::Ready(Ok(&buf[..cap]))
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
// Don't let callers reset the limit by passing an overlarge value
let amt = cmp::min(amt as u64, self.limit_) as usize;
*self.as_mut().limit_() -= amt as u64;
self.inner().consume(amt);
}
}

0 comments on commit ddf4f55

Please sign in to comment.