io: change AsyncRead to use a ReadBuf
seanmonstar committed Aug 10, 2020
1 parent d8490c1 commit b889ae8
Showing 31 changed files with 507 additions and 396 deletions.
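
Before the per-file diffs, a note on the shape of the change: `poll_read` no longer takes a plain `&mut [u8]` and returns the number of bytes read as a `usize`; it now takes a `ReadBuf<'_>`, which wraps caller-provided storage and tracks how much of it has been filled and initialized, and returns `Poll<io::Result<()>>`. The snippet below is a minimal sketch of that bookkeeping, not code from this commit; it assumes a `ReadBuf::new` constructor alongside the `append`, `filled`, and `remaining` methods that appear in the diffs below.

    use tokio::io::ReadBuf;

    fn main() {
        // Caller-provided storage; the ReadBuf tracks a filled prefix inside it.
        let mut storage = [0u8; 8];
        let mut buf = ReadBuf::new(&mut storage);
        assert_eq!(buf.remaining(), 8);

        // A reader reports progress by appending bytes, not by returning a count.
        buf.append(b"hello");
        assert_eq!(buf.filled(), &b"hello"[..]);
        assert_eq!(buf.remaining(), 3);
    }
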
tokio/src/fs/file.rs (19 changes: 7 additions & 12 deletions)

@@ -5,7 +5,7 @@
 use self::State::*;
 use crate::fs::{asyncify, sys};
 use crate::io::blocking::Buf;
-use crate::io::{AsyncRead, AsyncSeek, AsyncWrite};
+use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

 use std::fmt;
 use std::fs::{Metadata, Permissions};
@@ -537,25 +537,20 @@ impl File {
 }

 impl AsyncRead for File {
-    unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
-        // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668
-        false
-    }
-
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        dst: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
+        dst: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
         loop {
             match self.state {
                 Idle(ref mut buf_cell) => {
                     let mut buf = buf_cell.take().unwrap();

                     if !buf.is_empty() {
-                        let n = buf.copy_to(dst);
+                        buf.copy_to(dst);
                         *buf_cell = Some(buf);
-                        return Ready(Ok(n));
+                        return Ready(Ok(()));
                     }

                     buf.ensure_capacity_for(dst);
@@ -571,9 +566,9 @@ impl AsyncRead for File {

                     match op {
                         Operation::Read(Ok(_)) => {
-                            let n = buf.copy_to(dst);
+                            buf.copy_to(dst);
                             self.state = Idle(Some(buf));
-                            return Ready(Ok(n));
+                            return Ready(Ok(()));
                         }
                         Operation::Read(Err(e)) => {
                             assert!(buf.is_empty());

tokio/src/io/async_read.rs (118 changes: 37 additions & 81 deletions)

@@ -1,6 +1,6 @@
+use super::ReadBuf;
 use bytes::BufMut;
 use std::io;
-use std::mem::MaybeUninit;
 use std::ops::DerefMut;
 use std::pin::Pin;
 use std::task::{Context, Poll};
@@ -41,47 +41,6 @@ use std::task::{Context, Poll};
 /// [`Read::read`]: std::io::Read::read
 /// [`AsyncReadExt`]: crate::io::AsyncReadExt
 pub trait AsyncRead {
-    /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns
-    /// `true` if the supplied buffer was zeroed out.
-    ///
-    /// While it would be highly unusual, implementations of [`io::Read`] are
-    /// able to read data from the buffer passed as an argument. Because of
-    /// this, the buffer passed to [`io::Read`] must be initialized memory. In
-    /// situations where large numbers of buffers are used, constantly having to
-    /// zero out buffers can be expensive.
-    ///
-    /// This function does any necessary work to prepare an uninitialized buffer
-    /// to be safe to pass to `read`. If `read` guarantees to never attempt to
-    /// read data out of the supplied buffer, then `prepare_uninitialized_buffer`
-    /// doesn't need to do any work.
-    ///
-    /// If this function returns `true`, then the memory has been zeroed out.
-    /// This allows implementations of `AsyncRead` which are composed of
-    /// multiple subimplementations to efficiently implement
-    /// `prepare_uninitialized_buffer`.
-    ///
-    /// This function isn't actually `unsafe` to call but `unsafe` to implement.
-    /// The implementer must ensure that either the whole `buf` has been zeroed
-    /// or `poll_read_buf()` overwrites the buffer without reading it and returns
-    /// correct value.
-    ///
-    /// This function is called from [`poll_read_buf`].
-    ///
-    /// # Safety
-    ///
-    /// Implementations that return `false` must never read from data slices
-    /// that they did not write to.
-    ///
-    /// [`io::Read`]: std::io::Read
-    /// [`poll_read_buf`]: method@Self::poll_read_buf
-    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
-        for x in buf {
-            *x = MaybeUninit::new(0);
-        }
-
-        true
-    }
-
     /// Attempts to read from the `AsyncRead` into `buf`.
     ///
     /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
@@ -93,8 +52,8 @@ pub trait AsyncRead {
     fn poll_read(
         self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>>;
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>>;

     /// Pulls some bytes from this source into the specified `BufMut`, returning
     /// how many bytes were read.
@@ -116,16 +75,10 @@ pub trait AsyncRead {

         unsafe {
             let n = {
-                let b = buf.bytes_mut();
-
-                self.prepare_uninitialized_buffer(b);
-
-                // Convert to `&mut [u8]`
-                let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]);
+                let mut b = ReadBuf::uninit(buf.bytes_mut());

-                let n = ready!(self.poll_read(cx, b))?;
-                assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold");
-                n
+                ready!(self.poll_read(cx, &mut b))?;
+                b.filled().len()
             };

             buf.advance_mut(n);
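
The default `poll_read_buf` above now wraps the possibly uninitialized `bytes_mut()` slice in `ReadBuf::uninit`, which is the point of the change: the buffer no longer has to be zeroed before being handed to `poll_read`, because `ReadBuf` tracks initialization itself. A standalone sketch of that path, assuming only the `uninit`, `append`, `filled`, and `remaining` methods used in this diff:

    use std::mem::MaybeUninit;

    use tokio::io::ReadBuf;

    fn main() {
        // Wrap uninitialized storage; no up-front zeroing is required.
        let mut storage: [MaybeUninit<u8>; 16] = unsafe { MaybeUninit::uninit().assume_init() };
        let mut buf = ReadBuf::uninit(&mut storage);

        // append() initializes exactly the bytes it writes and marks them filled.
        buf.append(b"abc");
        assert_eq!(buf.filled(), &b"abc"[..]);
        assert_eq!(buf.remaining(), 13);
    }
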
@@ -136,15 +89,11 @@ pub trait AsyncRead {

 macro_rules! deref_async_read {
     () => {
-        unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
-            (**self).prepare_uninitialized_buffer(buf)
-        }
-
         fn poll_read(
             mut self: Pin<&mut Self>,
             cx: &mut Context<'_>,
-            buf: &mut [u8],
-        ) -> Poll<io::Result<usize>> {
+            buf: &mut ReadBuf<'_>,
+        ) -> Poll<io::Result<()>> {
             Pin::new(&mut **self).poll_read(cx, buf)
         }
     };
@@ -163,43 +112,50 @@ where
     P: DerefMut + Unpin,
     P::Target: AsyncRead,
 {
-    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
-        (**self).prepare_uninitialized_buffer(buf)
-    }
-
     fn poll_read(
         self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
         self.get_mut().as_mut().poll_read(cx, buf)
     }
 }

 impl AsyncRead for &[u8] {
-    unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
-        false
-    }
-
     fn poll_read(
-        self: Pin<&mut Self>,
+        mut self: Pin<&mut Self>,
         _cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
-        Poll::Ready(io::Read::read(self.get_mut(), buf))
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let amt = std::cmp::min(self.len(), buf.remaining());
+        let (a, b) = self.split_at(amt);
+        buf.append(a);
+        *self = b;
+        Poll::Ready(Ok(()))
     }
 }

 impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> {
-    unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
-        false
-    }
-
     fn poll_read(
-        self: Pin<&mut Self>,
+        mut self: Pin<&mut Self>,
         _cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
-        Poll::Ready(io::Read::read(self.get_mut(), buf))
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let pos = self.position();
+        let slice: &[u8] = (*self).get_ref().as_ref();
+
+        // The position could technically be out of bounds, so don't panic...
+        if pos > slice.len() as u64 {
+            return Poll::Ready(Ok(()));
+        }
+
+        let start = pos as usize;
+        let amt = std::cmp::min(slice.len() - start, buf.remaining());
+        // Add won't overflow because of pos check above.
+        let end = start + amt;
+        buf.append(&slice[start..end]);
+        self.set_position(end as u64);
+
+        Poll::Ready(Ok(()))
     }
 }
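
For comparison with the `&[u8]` and `io::Cursor` impls above, this is roughly what a hand-written implementation looks like under the new signature. It is a sketch only: the `SliceReader` type is invented for illustration, and with `ReadBuf` an end of stream is signaled by returning `Ok(())` without filling anything.

    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, ReadBuf};

    // Invented type for illustration: serves bytes from an in-memory slice.
    struct SliceReader {
        data: &'static [u8],
    }

    impl AsyncRead for SliceReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // Copy at most `buf.remaining()` bytes; the caller observes the
            // amount read through `buf.filled()`, not through a return value.
            let data = self.data;
            let n = std::cmp::min(data.len(), buf.remaining());
            let (head, rest) = data.split_at(n);
            buf.append(head);
            self.data = rest;
            Poll::Ready(Ok(()))
        }
    }
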
tokio/src/io/blocking.rs (24 changes: 12 additions & 12 deletions)

@@ -1,5 +1,5 @@
 use crate::io::sys;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};

 use std::cmp;
 use std::future::Future;
@@ -53,17 +53,17 @@ where
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        dst: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
+        dst: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
         loop {
             match self.state {
                 Idle(ref mut buf_cell) => {
                     let mut buf = buf_cell.take().unwrap();

                     if !buf.is_empty() {
-                        let n = buf.copy_to(dst);
+                        buf.copy_to(dst);
                         *buf_cell = Some(buf);
-                        return Ready(Ok(n));
+                        return Ready(Ok(()));
                     }

                     buf.ensure_capacity_for(dst);
@@ -80,9 +80,9 @@ where

                     match res {
                         Ok(_) => {
-                            let n = buf.copy_to(dst);
+                            buf.copy_to(dst);
                             self.state = Idle(Some(buf));
-                            return Ready(Ok(n));
+                            return Ready(Ok(()));
                         }
                         Err(e) => {
                             assert!(buf.is_empty());
@@ -203,9 +203,9 @@ impl Buf {
         self.buf.len() - self.pos
     }

-    pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize {
-        let n = cmp::min(self.len(), dst.len());
-        dst[..n].copy_from_slice(&self.bytes()[..n]);
+    pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
+        let n = cmp::min(self.len(), dst.remaining());
+        dst.append(&self.bytes()[..n]);
         self.pos += n;

         if self.pos == self.buf.len() {
@@ -229,10 +229,10 @@ impl Buf {
         &self.buf[self.pos..]
     }

-    pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) {
+    pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
         assert!(self.is_empty());

-        let len = cmp::min(bytes.len(), MAX_BUF);
+        let len = cmp::min(bytes.remaining(), MAX_BUF);

         if self.buf.len() < len {
             self.buf.reserve(len - self.buf.len());

tokio/src/io/mod.rs (3 changes: 3 additions & 0 deletions)

@@ -196,6 +196,9 @@ pub use self::async_seek::AsyncSeek;
 mod async_write;
 pub use self::async_write::AsyncWrite;

+mod read_buf;
+pub use self::read_buf::ReadBuf;
+
 // Re-export some types from `std::io` so that users don't have to deal
 // with conflicts when `use`ing `tokio::io` and `std::io`.
 pub use std::io::{Error, ErrorKind, Result, SeekFrom};

tokio/src/io/poll_evented.rs (14 changes: 9 additions & 5 deletions)

@@ -1,5 +1,5 @@
 use crate::io::driver::platform;
-use crate::io::{AsyncRead, AsyncWrite, Registration};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf, Registration};

 use mio::event::Evented;
 use std::fmt;
@@ -384,18 +384,22 @@ where
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        buf: &mut [u8],
-    ) -> Poll<io::Result<usize>> {
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
         ready!(self.poll_read_ready(cx, mio::Ready::readable()))?;

-        let r = (*self).get_mut().read(buf);
+        // We can't assume the `Read` won't look at the read buffer,
+        // so we have to force initialization here.
+        let r = (*self).get_mut().read(buf.initialize_unfilled());

         if is_wouldblock(&r) {
             self.clear_read_ready(cx, mio::Ready::readable())?;
             return Poll::Pending;
         }

-        Poll::Ready(r)
+        Poll::Ready(r.map(|n| {
+            buf.add_filled(n);
+        }))
     }
 }
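
The `poll_evented.rs` change shows the general recipe for driving a plain `std::io::Read` through the new interface: read into `initialize_unfilled()`, which hands the inner reader an initialized `&mut [u8]`, then record the count with `add_filled(n)`. Below is a stripped-down sketch of that recipe; the `SyncReader` wrapper is invented for illustration, and the readiness registration and `WouldBlock` handling that `PollEvented` performs are deliberately omitted.

    use std::io::{self, Read};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, ReadBuf};

    // Invented wrapper around a synchronous reader that never blocks.
    struct SyncReader<R> {
        inner: R,
    }

    impl<R: Read + Unpin> AsyncRead for SyncReader<R> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // `Read::read` needs initialized memory, so initialize the unfilled
            // region, read into it, then tell the ReadBuf how many of those
            // bytes are now valid.
            let n = self.inner.read(buf.initialize_unfilled())?;
            buf.add_filled(n);
            Poll::Ready(Ok(()))
        }
    }
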
