Skip to content

Commit

Permalink
io: add ReaderStream (#2714)
Browse files Browse the repository at this point in the history
  • Loading branch information
MikailBag committed Aug 23, 2020
1 parent 1167c09 commit 30d4ec0
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 0 deletions.
1 change: 1 addition & 0 deletions tokio/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ cfg_io_util! {

cfg_stream! {
pub use util::{stream_reader, StreamReader};
pub use util::{reader_stream, ReaderStream};
}
}

Expand Down
3 changes: 3 additions & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ cfg_io_util! {
cfg_stream! {
mod stream_reader;
pub use stream_reader::{stream_reader, StreamReader};

mod reader_stream;
pub use reader_stream::{reader_stream, ReaderStream};
}

mod take;
Expand Down
105 changes: 105 additions & 0 deletions tokio/src/io/util/reader_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use crate::io::AsyncRead;
use crate::stream::Stream;
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Convert an [`AsyncRead`] implementor into a
/// [`Stream`] of Result<[`Bytes`], std::io::Error>.
/// After first error it will stop.
/// Additionally, this stream is fused: after it returns None at some
/// moment, it is guaranteed that further `next()`, `poll_next()` and
/// similar functions will instantly return None.
///
/// This type can be created using the [`reader_stream`] function
///
/// [`AsyncRead`]: crate::io::AsyncRead
/// [`Stream`]: crate::stream::Stream
/// [`Bytes`]: bytes::Bytes
/// [`reader_stream`]: crate::io::reader_stream
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct ReaderStream<R> {
// Reader itself.
// None if we had error reading from the `reader` in the past.
#[pin]
reader: Option<R>,
// Working buffer, used to optimize allocations.
// # Capacity behavior
// Initially `buf` is empty. Also it's getting smaller and smaller
// during polls (because its chunks are returned to stream user).
// But when it's capacity reaches 0, it is growed.
buf: BytesMut,
}
}

/// Convert an [`AsyncRead`] implementor into a
/// [`Stream`] of Result<[`Bytes`], std::io::Error>.
///
/// # Example
///
/// ```
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// use tokio::stream::StreamExt;
///
/// let data: &[u8] = b"hello, world!";
/// let mut stream = tokio::io::reader_stream(data);
/// let mut stream_contents = Vec::new();
/// while let Some(chunk) = stream.next().await {
/// stream_contents.extend_from_slice(chunk?.as_ref());
/// }
/// assert_eq!(stream_contents, data);
/// # Ok(())
/// # }
/// ```
///
/// [`AsyncRead`]: crate::io::AsyncRead
/// [`Stream`]: crate::stream::Stream
/// [`Bytes`]: bytes::Bytes
pub fn reader_stream<R>(reader: R) -> ReaderStream<R>
where
R: AsyncRead,
{
ReaderStream {
reader: Some(reader),
buf: BytesMut::new(),
}
}

const CAPACITY: usize = 4096;

impl<R> Stream for ReaderStream<R>
where
R: AsyncRead,
{
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
let reader = match this.reader.as_pin_mut() {
Some(r) => r,
None => return Poll::Ready(None),
};
if this.buf.capacity() == 0 {
this.buf.reserve(CAPACITY);
}
match reader.poll_read_buf(cx, &mut this.buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
self.project().reader.set(None);
Poll::Ready(Some(Err(err)))
}
Poll::Ready(Ok(0)) => {
self.project().reader.set(None);
Poll::Ready(None)
}
Poll::Ready(Ok(_)) => {
let chunk = this.buf.split();
Poll::Ready(Some(Ok(chunk.freeze())))
}
}
}
}
64 changes: 64 additions & 0 deletions tokio/tests/io_reader_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio::stream::StreamExt;

/// produces at most `remaining` zeros, that returns error.
/// each time it reads at most 31 byte.
struct Reader {
remaining: usize,
}

impl AsyncRead for Reader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let this = Pin::into_inner(self);
assert_ne!(buf.len(), 0);
if this.remaining > 0 {
let n = std::cmp::min(this.remaining, buf.len());
let n = std::cmp::min(n, 31);
for x in &mut buf[..n] {
*x = 0;
}
this.remaining -= n;
Poll::Ready(Ok(n))
} else {
Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
}
}
}

#[tokio::test]
async fn correct_behavior_on_errors() {
let reader = Reader { remaining: 8000 };
let mut stream = tokio::io::reader_stream(reader);
let mut zeros_received = 0;
let mut had_error = false;
loop {
let item = stream.next().await.unwrap();
match item {
Ok(bytes) => {
let bytes = &*bytes;
for byte in bytes {
assert_eq!(*byte, 0);
zeros_received += 1;
}
}
Err(_) => {
assert!(!had_error);
had_error = true;
break;
}
}
}

assert!(had_error);
assert_eq!(zeros_received, 8000);
assert!(stream.next().await.is_none());
}

0 comments on commit 30d4ec0

Please sign in to comment.