diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs index 26ef57e47a2..f1ce77aa248 100644 --- a/tokio-test/src/io.rs +++ b/tokio-test/src/io.rs @@ -18,7 +18,7 @@ //! [`AsyncRead`]: tokio::io::AsyncRead //! [`AsyncWrite`]: tokio::io::AsyncWrite -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::mpsc; use tokio::time::{self, Delay, Duration, Instant}; @@ -204,20 +204,19 @@ impl Inner { self.rx.poll_recv(cx) } - fn read(&mut self, dst: &mut [u8]) -> io::Result { + fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> { match self.action() { Some(&mut Action::Read(ref mut data)) => { // Figure out how much to copy - let n = cmp::min(dst.len(), data.len()); + let n = cmp::min(dst.remaining(), data.len()); // Copy the data into the `dst` slice - (&mut dst[..n]).copy_from_slice(&data[..n]); + dst.append(&data[..n]); // Drain the data from the source data.drain(..n); - // Return the number of bytes read - Ok(n) + Ok(()) } Some(&mut Action::ReadError(ref mut err)) => { // As the @@ -229,7 +228,7 @@ impl Inner { // Either waiting or expecting a write Err(io::ErrorKind::WouldBlock.into()) } - None => Ok(0), + None => Ok(()), } } @@ -348,8 +347,8 @@ impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { loop { if let Some(ref mut sleep) = self.inner.sleep { ready!(Pin::new(sleep).poll(cx)); @@ -358,6 +357,9 @@ impl AsyncRead for Mock { // If a sleep is set, it has already fired self.inner.sleep = None; + // Capture 'filled' to monitor if it changed + let filled = buf.filled().len(); + match self.inner.read(buf) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { if let Some(rem) = self.inner.remaining_wait() { @@ -368,19 +370,22 @@ impl AsyncRead for Mock { return Poll::Pending; } } - Ok(0) => { - // TODO: Extract - match ready!(self.inner.poll_action(cx)) { - Some(action) => { - self.inner.actions.push_back(action); - continue; - } - None => { - return Poll::Ready(Ok(0)); + Ok(()) => { + if buf.filled().len() == filled { + match ready!(self.inner.poll_action(cx)) { + Some(action) => { + self.inner.actions.push_back(action); + continue; + } + None => { + return Poll::Ready(Ok(())); + } } + } else { + return Poll::Ready(Ok(())); } } - ret => return Poll::Ready(ret), + Err(e) => return Poll::Ready(Err(e)), } } } diff --git a/tokio-util/src/compat.rs b/tokio-util/src/compat.rs index 769e30c2bb9..34120d43a34 100644 --- a/tokio-util/src/compat.rs +++ b/tokio-util/src/compat.rs @@ -1,5 +1,6 @@ //! Compatibility between the `tokio::io` and `futures-io` versions of the //! `AsyncRead` and `AsyncWrite` traits. +use futures_core::ready; use pin_project_lite::pin_project; use std::io; use std::pin::Pin; @@ -107,9 +108,18 @@ where fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - futures_io::AsyncRead::poll_read(self.project().inner, cx, buf) + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + // We can't trust the inner type to not peak at the bytes, + // so we must defensively initialize the buffer. 
+ let slice = buf.initialize_unfilled(); + let n = ready!(futures_io::AsyncRead::poll_read( + self.project().inner, + cx, + slice + ))?; + buf.add_filled(n); + Poll::Ready(Ok(())) } } @@ -120,9 +130,15 @@ where fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], + slice: &mut [u8], ) -> Poll> { - tokio::io::AsyncRead::poll_read(self.project().inner, cx, buf) + let mut buf = tokio::io::ReadBuf::new(slice); + ready!(tokio::io::AsyncRead::poll_read( + self.project().inner, + cx, + &mut buf + ))?; + Poll::Ready(Ok(buf.filled().len())) } } diff --git a/tokio-util/tests/framed.rs b/tokio-util/tests/framed.rs index d7ee3ef51fb..4c5f8418615 100644 --- a/tokio-util/tests/framed.rs +++ b/tokio-util/tests/framed.rs @@ -55,8 +55,8 @@ impl AsyncRead for DontReadIntoThis { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { unreachable!() } } diff --git a/tokio-util/tests/framed_read.rs b/tokio-util/tests/framed_read.rs index 27bb298a7fa..da38c432326 100644 --- a/tokio-util/tests/framed_read.rs +++ b/tokio-util/tests/framed_read.rs @@ -1,6 +1,6 @@ #![warn(rust_2018_idioms)] -use tokio::io::AsyncRead; +use tokio::io::{AsyncRead, ReadBuf}; use tokio_test::assert_ready; use tokio_test::task; use tokio_util::codec::{Decoder, FramedRead}; @@ -264,19 +264,19 @@ impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { use io::ErrorKind::WouldBlock; match self.calls.pop_front() { Some(Ok(data)) => { - debug_assert!(buf.len() >= data.len()); - buf[..data.len()].copy_from_slice(&data[..]); - Ready(Ok(data.len())) + debug_assert!(buf.remaining() >= data.len()); + buf.append(&data); + Ready(Ok(())) } Some(Err(ref e)) if e.kind() == WouldBlock => Pending, Some(Err(e)) => Ready(Err(e)), - None => Ready(Ok(0)), + None => Ready(Ok(())), } } } @@ -288,8 +288,8 @@ impl AsyncRead for Slice<'_> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } diff --git a/tokio-util/tests/length_delimited.rs b/tokio-util/tests/length_delimited.rs index 734cd834da1..9f615412875 100644 --- a/tokio-util/tests/length_delimited.rs +++ b/tokio-util/tests/length_delimited.rs @@ -1,6 +1,6 @@ #![warn(rust_2018_idioms)] -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_test::task; use tokio_test::{ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, @@ -707,18 +707,18 @@ impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - dst: &mut [u8], - ) -> Poll> { + dst: &mut ReadBuf<'_>, + ) -> Poll> { match self.calls.pop_front() { Some(Ready(Ok(Op::Data(data)))) => { - debug_assert!(dst.len() >= data.len()); - dst[..data.len()].copy_from_slice(&data[..]); - Ready(Ok(data.len())) + debug_assert!(dst.remaining() >= data.len()); + dst.append(&data); + Ready(Ok(())) } Some(Ready(Ok(_))) => panic!(), Some(Ready(Err(e))) => Ready(Err(e)), Some(Pending) => Pending, - None => Ready(Ok(0)), + None => Ready(Ok(())), } } } diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index c44196b3e74..2c36806d870 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -5,7 +5,7 @@ use self::State::*; use crate::fs::{asyncify, sys}; use crate::io::blocking::Buf; -use crate::io::{AsyncRead, 
AsyncSeek, AsyncWrite}; +use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use std::fmt; use std::fs::{Metadata, Permissions}; @@ -537,25 +537,20 @@ impl File { } impl AsyncRead for File { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - dst: &mut [u8], - ) -> Poll> { + dst: &mut ReadBuf<'_>, + ) -> Poll> { loop { match self.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); if !buf.is_empty() { - let n = buf.copy_to(dst); + buf.copy_to(dst); *buf_cell = Some(buf); - return Ready(Ok(n)); + return Ready(Ok(())); } buf.ensure_capacity_for(dst); @@ -571,9 +566,9 @@ impl AsyncRead for File { match op { Operation::Read(Ok(_)) => { - let n = buf.copy_to(dst); + buf.copy_to(dst); self.state = Idle(Some(buf)); - return Ready(Ok(n)); + return Ready(Ok(())); } Operation::Read(Err(e)) => { assert!(buf.is_empty()); diff --git a/tokio/src/io/async_read.rs b/tokio/src/io/async_read.rs index 1aef4150166..d341b63d41a 100644 --- a/tokio/src/io/async_read.rs +++ b/tokio/src/io/async_read.rs @@ -1,6 +1,6 @@ +use super::ReadBuf; use bytes::BufMut; use std::io; -use std::mem::MaybeUninit; use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -41,47 +41,6 @@ use std::task::{Context, Poll}; /// [`Read::read`]: std::io::Read::read /// [`AsyncReadExt`]: crate::io::AsyncReadExt pub trait AsyncRead { - /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns - /// `true` if the supplied buffer was zeroed out. - /// - /// While it would be highly unusual, implementations of [`io::Read`] are - /// able to read data from the buffer passed as an argument. Because of - /// this, the buffer passed to [`io::Read`] must be initialized memory. In - /// situations where large numbers of buffers are used, constantly having to - /// zero out buffers can be expensive. - /// - /// This function does any necessary work to prepare an uninitialized buffer - /// to be safe to pass to `read`. If `read` guarantees to never attempt to - /// read data out of the supplied buffer, then `prepare_uninitialized_buffer` - /// doesn't need to do any work. - /// - /// If this function returns `true`, then the memory has been zeroed out. - /// This allows implementations of `AsyncRead` which are composed of - /// multiple subimplementations to efficiently implement - /// `prepare_uninitialized_buffer`. - /// - /// This function isn't actually `unsafe` to call but `unsafe` to implement. - /// The implementer must ensure that either the whole `buf` has been zeroed - /// or `poll_read_buf()` overwrites the buffer without reading it and returns - /// correct value. - /// - /// This function is called from [`poll_read_buf`]. - /// - /// # Safety - /// - /// Implementations that return `false` must never read from data slices - /// that they did not write to. - /// - /// [`io::Read`]: std::io::Read - /// [`poll_read_buf`]: method@Self::poll_read_buf - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - for x in buf { - *x = MaybeUninit::new(0); - } - - true - } - /// Attempts to read from the `AsyncRead` into `buf`. /// /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. 
@@ -93,8 +52,8 @@ pub trait AsyncRead { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll>; + buf: &mut ReadBuf<'_>, + ) -> Poll>; /// Pulls some bytes from this source into the specified `BufMut`, returning /// how many bytes were read. @@ -114,37 +73,26 @@ pub trait AsyncRead { return Poll::Ready(Ok(0)); } - unsafe { - let n = { - let b = buf.bytes_mut(); - - self.prepare_uninitialized_buffer(b); - - // Convert to `&mut [u8]` - let b = &mut *(b as *mut [MaybeUninit] as *mut [u8]); + let mut b = ReadBuf::uninit(buf.bytes_mut()); - let n = ready!(self.poll_read(cx, b))?; - assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold"); - n - }; + ready!(self.poll_read(cx, &mut b))?; + let n = b.filled().len(); + // Safety: we can assume `n` bytes were read, since they are in`filled`. + unsafe { buf.advance_mut(n); - Poll::Ready(Ok(n)) } + Poll::Ready(Ok(n)) } } macro_rules! deref_async_read { () => { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - (**self).prepare_uninitialized_buffer(buf) - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut **self).poll_read(cx, buf) } }; @@ -163,43 +111,50 @@ where P: DerefMut + Unpin, P::Target: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - (**self).prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.get_mut().as_mut().poll_read(cx, buf) } } impl AsyncRead for &[u8] { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Poll::Ready(io::Read::read(self.get_mut(), buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let amt = std::cmp::min(self.len(), buf.remaining()); + let (a, b) = self.split_at(amt); + buf.append(a); + *self = b; + Poll::Ready(Ok(())) } } impl + Unpin> AsyncRead for io::Cursor { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Poll::Ready(io::Read::read(self.get_mut(), buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let pos = self.position(); + let slice: &[u8] = (*self).get_ref().as_ref(); + + // The position could technically be out of bounds, so don't panic... + if pos > slice.len() as u64 { + return Poll::Ready(Ok(())); + } + + let start = pos as usize; + let amt = std::cmp::min(slice.len() - start, buf.remaining()); + // Add won't overflow because of pos check above. 
+ let end = start + amt; + buf.append(&slice[start..end]); + self.set_position(end as u64); + + Poll::Ready(Ok(())) } } diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index 2491039a3f3..d2265a00aa2 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -1,5 +1,5 @@ use crate::io::sys; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::cmp; use std::future::Future; @@ -53,17 +53,17 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - dst: &mut [u8], - ) -> Poll> { + dst: &mut ReadBuf<'_>, + ) -> Poll> { loop { match self.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); if !buf.is_empty() { - let n = buf.copy_to(dst); + buf.copy_to(dst); *buf_cell = Some(buf); - return Ready(Ok(n)); + return Ready(Ok(())); } buf.ensure_capacity_for(dst); @@ -80,9 +80,9 @@ where match res { Ok(_) => { - let n = buf.copy_to(dst); + buf.copy_to(dst); self.state = Idle(Some(buf)); - return Ready(Ok(n)); + return Ready(Ok(())); } Err(e) => { assert!(buf.is_empty()); @@ -203,9 +203,9 @@ impl Buf { self.buf.len() - self.pos } - pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize { - let n = cmp::min(self.len(), dst.len()); - dst[..n].copy_from_slice(&self.bytes()[..n]); + pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize { + let n = cmp::min(self.len(), dst.remaining()); + dst.append(&self.bytes()[..n]); self.pos += n; if self.pos == self.buf.len() { @@ -229,10 +229,10 @@ impl Buf { &self.buf[self.pos..] } - pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) { + pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) { assert!(self.is_empty()); - let len = cmp::min(bytes.len(), MAX_BUF); + let len = cmp::min(bytes.remaining(), MAX_BUF); if self.buf.len() < len { self.buf.reserve(len - self.buf.len()); diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 9e0e063195c..c43f0e83140 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -196,6 +196,9 @@ pub use self::async_seek::AsyncSeek; mod async_write; pub use self::async_write::AsyncWrite; +mod read_buf; +pub use self::read_buf::ReadBuf; + // Re-export some types from `std::io` so that users don't have to deal // with conflicts when `use`ing `tokio::io` and `std::io`. pub use std::io::{Error, ErrorKind, Result, SeekFrom}; diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 5295bd71ad8..785968f43f8 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -1,5 +1,5 @@ use crate::io::driver::platform; -use crate::io::{AsyncRead, AsyncWrite, Registration}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf, Registration}; use mio::event::Evented; use std::fmt; @@ -384,18 +384,22 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { ready!(self.poll_read_ready(cx, mio::Ready::readable()))?; - let r = (*self).get_mut().read(buf); + // We can't assume the `Read` won't look at the read buffer, + // so we have to force initialization here. 
+ let r = (*self).get_mut().read(buf.initialize_unfilled()); if is_wouldblock(&r) { self.clear_read_ready(cx, mio::Ready::readable())?; return Poll::Pending; } - Poll::Ready(r) + Poll::Ready(r.map(|n| { + buf.add_filled(n); + })) } } diff --git a/tokio/src/io/read_buf.rs b/tokio/src/io/read_buf.rs new file mode 100644 index 00000000000..03b5d05ca03 --- /dev/null +++ b/tokio/src/io/read_buf.rs @@ -0,0 +1,253 @@ +// This lint claims ugly casting is somehow safer than transmute, but there's +// no evidence that is the case. Shush. +#![allow(clippy::transmute_ptr_to_ptr)] + +use std::fmt; +use std::mem::{self, MaybeUninit}; + +/// A wrapper around a byte buffer that is incrementally filled and initialized. +/// +/// This type is a sort of "double cursor". It tracks three regions in the +/// buffer: a region at the beginning of the buffer that has been logically +/// filled with data, a region that has been initialized at some point but not +/// yet logically filled, and a region at the end that is fully uninitialized. +/// The filled region is guaranteed to be a subset of the initialized region. +/// +/// In summary, the contents of the buffer can be visualized as: +/// +/// ```not_rust +/// [ capacity ] +/// [ filled | unfilled ] +/// [ initialized | uninitialized ] +/// ``` +pub struct ReadBuf<'a> { + buf: &'a mut [MaybeUninit], + filled: usize, + initialized: usize, +} + +impl<'a> ReadBuf<'a> { + /// Creates a new `ReadBuf` from a fully initialized buffer. + #[inline] + pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> { + let initialized = buf.len(); + let buf = unsafe { mem::transmute::<&mut [u8], &mut [MaybeUninit]>(buf) }; + ReadBuf { + buf, + filled: 0, + initialized, + } + } + + /// Creates a new `ReadBuf` from a fully uninitialized buffer. + /// + /// Use `assume_init` if part of the buffer is known to be already inintialized. + #[inline] + pub fn uninit(buf: &'a mut [MaybeUninit]) -> ReadBuf<'a> { + ReadBuf { + buf, + filled: 0, + initialized: 0, + } + } + + /// Returns the total capacity of the buffer. + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + /// Returns a shared reference to the filled portion of the buffer. + #[inline] + pub fn filled(&self) -> &[u8] { + let slice = &self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable. + unsafe { mem::transmute::<&[MaybeUninit], &[u8]>(slice) } + } + + /// Returns a mutable reference to the filled portion of the buffer. + #[inline] + pub fn filled_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable. + unsafe { mem::transmute::<&mut [MaybeUninit], &mut [u8]>(slice) } + } + + /// Returns a shared reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized(&self) -> &[u8] { + let slice = &self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable. + unsafe { mem::transmute::<&[MaybeUninit], &[u8]>(slice) } + } + + /// Returns a mutable reference to the initialized portion of the buffer. 
+ /// + /// This includes the filled portion. + #[inline] + pub fn initialized_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable. + unsafe { mem::transmute::<&mut [MaybeUninit], &mut [u8]>(slice) } + } + + /// Returns a mutable reference to the unfilled part of the buffer without ensuring that it has been fully + /// initialized. + /// + /// # Safety + /// + /// The caller must not de-initialize portions of the buffer that have already been initialized. + #[inline] + pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit] { + &mut self.buf[self.filled..] + } + + /// Returns a mutable reference to the unfilled part of the buffer, ensuring it is fully initialized. + /// + /// Since `ReadBuf` tracks the region of the buffer that has been initialized, this is effectively "free" after + /// the first use. + #[inline] + pub fn initialize_unfilled(&mut self) -> &mut [u8] { + self.initialize_unfilled_to(self.remaining()) + } + + /// Returns a mutable reference to the first `n` bytes of the unfilled part of the buffer, ensuring it is + /// fully initialized. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `n`. + #[inline] + pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] { + assert!(self.remaining() >= n, "n overflows remaining"); + + // This can't overflow, otherwise the assert above would have failed. + let end = self.filled + n; + + if self.initialized < end { + unsafe { + self.buf[self.initialized..end] + .as_mut_ptr() + .write_bytes(0, end - self.initialized); + } + self.initialized = end; + } + + let slice = &mut self.buf[self.filled..end]; + // safety: just above, we checked that the end of the buf has + // been initialized to some value. + unsafe { mem::transmute::<&mut [MaybeUninit], &mut [u8]>(slice) } + } + + /// Returns the number of bytes at the end of the slice that have not yet been filled. + #[inline] + pub fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + /// Clears the buffer, resetting the filled region to empty. + /// + /// The number of initialized bytes is not changed, and the contents of the buffer are not modified. + #[inline] + pub fn clear(&mut self) { + self.filled = 0; + } + + /// Increases the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the initialized region. + #[inline] + pub fn add_filled(&mut self, n: usize) { + let new = self.filled.checked_add(n).expect("filled overflow"); + self.set_filled(new); + } + + /// Sets the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// Note that this can be used to *shrink* the filled region of the buffer in addition to growing it (for + /// example, by a `AsyncRead` implementation that compresses data in-place). + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the intialized region. + #[inline] + pub fn set_filled(&mut self, n: usize) { + assert!( + n <= self.initialized, + "filled must not become larger than initialized" + ); + self.filled = n; + } + + /// Asserts that the first `n` unfilled bytes of the buffer are initialized. 
+ /// + /// `ReadBuf` assumes that bytes are never de-initialized, so this method does nothing when called with fewer + /// bytes than are already known to be initialized. + /// + /// # Safety + /// + /// The caller must ensure that `n` unfilled bytes of the buffer have already been initialized. + #[inline] + pub unsafe fn assume_init(&mut self, n: usize) { + let new = self.filled + n; + if new > self.initialized { + self.initialized = new; + } + } + + /// Appends data to the buffer, advancing the written position and possibly also the initialized position. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `buf.len()`. + #[inline] + pub fn append(&mut self, buf: &[u8]) { + assert!( + self.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf[self.filled..end] + .as_mut_ptr() + .cast::() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.initialized < end { + self.initialized = end; + } + self.filled = end; + } +} + +impl fmt::Debug for ReadBuf<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("initialized", &self.initialized) + .field("capacity", &self.capacity()) + .finish() + } +} diff --git a/tokio/src/io/split.rs b/tokio/src/io/split.rs index 134b937a5f1..dcd3da2032b 100644 --- a/tokio/src/io/split.rs +++ b/tokio/src/io/split.rs @@ -4,7 +4,7 @@ //! To restore this read/write object from its `split::ReadHalf` and //! `split::WriteHalf` use `unsplit`. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use bytes::{Buf, BufMut}; use std::cell::UnsafeCell; @@ -102,8 +102,8 @@ impl AsyncRead for ReadHalf { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { let mut inner = ready!(self.inner.poll_lock(cx)); inner.stream_pin().poll_read(cx, buf) } diff --git a/tokio/src/io/stdin.rs b/tokio/src/io/stdin.rs index 325b8757ec1..c9578f17b64 100644 --- a/tokio/src/io/stdin.rs +++ b/tokio/src/io/stdin.rs @@ -1,5 +1,5 @@ use crate::io::blocking::Blocking; -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -63,16 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin { } impl AsyncRead for Stdin { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.std).poll_read(cx, buf) } } diff --git a/tokio/src/io/util/buf_reader.rs b/tokio/src/io/util/buf_reader.rs index a1c5990a644..3ab78f0eb8d 100644 --- a/tokio/src/io/util/buf_reader.rs +++ b/tokio/src/io/util/buf_reader.rs @@ -1,10 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use bytes::Buf; use pin_project_lite::pin_project; -use std::io::{self, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, fmt}; @@ -44,21 +43,12 @@ impl BufReader { /// Creates a new `BufReader` with the specified buffer 
capacity. pub fn with_capacity(capacity: usize, inner: R) -> Self { - unsafe { - let mut buffer = Vec::with_capacity(capacity); - buffer.set_len(capacity); - - { - // Convert to MaybeUninit - let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit]); - inner.prepare_uninitialized_buffer(b); - } - Self { - inner, - buf: buffer.into_boxed_slice(), - pos: 0, - cap: 0, - } + let buffer = vec![0; capacity]; + Self { + inner, + buf: buffer.into_boxed_slice(), + pos: 0, + cap: 0, } } @@ -110,25 +100,21 @@ impl AsyncRead for BufReader { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - if self.pos == self.cap && buf.len() >= self.buf.len() { + if self.pos == self.cap && buf.remaining() >= self.buf.len() { let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); self.discard_buffer(); return Poll::Ready(res); } - let mut rem = ready!(self.as_mut().poll_fill_buf(cx))?; - let nread = rem.read(buf)?; - self.consume(nread); - Poll::Ready(Ok(nread)) - } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.append(&rem[..amt]); + self.consume(amt); + Poll::Ready(Ok(())) } } @@ -142,7 +128,9 @@ impl AsyncBufRead for BufReader { // to tell the compiler that the pos..cap slice is always valid. if *me.pos >= *me.cap { debug_assert!(*me.pos == *me.cap); - *me.cap = ready!(me.inner.poll_read(cx, me.buf))?; + let mut buf = ReadBuf::new(me.buf); + ready!(me.inner.poll_read(cx, &mut buf))?; + *me.cap = buf.filled().len(); *me.pos = 0; } Poll::Ready(Ok(&me.buf[*me.pos..*me.cap])) diff --git a/tokio/src/io/util/buf_stream.rs b/tokio/src/io/util/buf_stream.rs index a56a4517fa4..cc857e225bc 100644 --- a/tokio/src/io/util/buf_stream.rs +++ b/tokio/src/io/util/buf_stream.rs @@ -1,9 +1,8 @@ use crate::io::util::{BufReader, BufWriter}; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -137,15 +136,10 @@ impl AsyncRead for BufStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.project().inner.poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. 
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } } impl AsyncBufRead for BufStream { diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index efd053ebac6..5e3d4b710f2 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -1,10 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; use std::io::{self, Write}; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -147,15 +146,10 @@ impl AsyncRead for BufWriter { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.get_pin_mut().poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.get_ref().prepare_uninitialized_buffer(buf) - } } impl AsyncBufRead for BufWriter { diff --git a/tokio/src/io/util/chain.rs b/tokio/src/io/util/chain.rs index 8ba9194f5de..84f37fc7d46 100644 --- a/tokio/src/io/util/chain.rs +++ b/tokio/src/io/util/chain.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; @@ -84,26 +84,20 @@ where T: AsyncRead, U: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { - if self.first.prepare_uninitialized_buffer(buf) { - return true; - } - if self.second.prepare_uninitialized_buffer(buf) { - return true; - } - false - } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { let me = self.project(); if !*me.done_first { - match ready!(me.first.poll_read(cx, buf)?) { - 0 if !buf.is_empty() => *me.done_first = true, - n => return Poll::Ready(Ok(n)), + let rem = buf.remaining(); + ready!(me.first.poll_read(cx, buf))?; + if buf.remaining() == rem { + *me.done_first = true; + } else { + return Poll::Ready(Ok(())); } } me.second.poll_read(cx, buf) diff --git a/tokio/src/io/util/copy.rs b/tokio/src/io/util/copy.rs index 7bfe296941e..86001ee7696 100644 --- a/tokio/src/io/util/copy.rs +++ b/tokio/src/io/util/copy.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::future::Future; use std::io; @@ -88,7 +88,9 @@ where // continue. if self.pos == self.cap && !self.read_done { let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + let mut buf = ReadBuf::new(&mut me.buf); + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?; + let n = buf.filled().len(); if n == 0 { self.read_done = true; } else { diff --git a/tokio/src/io/util/empty.rs b/tokio/src/io/util/empty.rs index 576058d52d1..f964d18e6ef 100644 --- a/tokio/src/io/util/empty.rs +++ b/tokio/src/io/util/empty.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use std::fmt; use std::io; @@ -47,16 +47,13 @@ cfg_io_util! 
{ } impl AsyncRead for Empty { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - _: &mut [u8], - ) -> Poll> { - Poll::Ready(Ok(0)) + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) } } diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs index 02ba6aa7e91..1b9b37b71dc 100644 --- a/tokio/src/io/util/mem.rs +++ b/tokio/src/io/util/mem.rs @@ -1,6 +1,6 @@ //! In-process memory IO types. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::loom::sync::Mutex; use bytes::{Buf, BytesMut}; @@ -98,8 +98,8 @@ impl AsyncRead for DuplexStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf) } } @@ -163,11 +163,12 @@ impl AsyncRead for Pipe { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { if self.buffer.has_remaining() { - let max = self.buffer.remaining().min(buf.len()); - self.buffer.copy_to_slice(&mut buf[..max]); + let max = self.buffer.remaining().min(buf.remaining()); + buf.append(&self.buffer[..max]); + self.buffer.advance(max); if max > 0 { // The passed `buf` might have been empty, don't wake up if // no bytes have been moved. @@ -175,9 +176,9 @@ impl AsyncRead for Pipe { waker.wake(); } } - Poll::Ready(Ok(max)) + Poll::Ready(Ok(())) } else if self.is_closed { - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) } else { self.read_waker = Some(cx.waker().clone()); Poll::Pending diff --git a/tokio/src/io/util/read.rs b/tokio/src/io/util/read.rs index a8ca370ea87..28470d5a5c6 100644 --- a/tokio/src/io/util/read.rs +++ b/tokio/src/io/util/read.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::future::Future; use std::io; @@ -39,7 +39,9 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let me = &mut *self; - Pin::new(&mut *me.reader).poll_read(cx, me.buf) + let mut buf = ReadBuf::new(me.buf); + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?; + Poll::Ready(Ok(buf.filled().len())) } } diff --git a/tokio/src/io/util/read_exact.rs b/tokio/src/io/util/read_exact.rs index 86b8412954b..970074aa5ec 100644 --- a/tokio/src/io/util/read_exact.rs +++ b/tokio/src/io/util/read_exact.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::future::Future; use std::io; @@ -17,8 +17,7 @@ where { ReadExact { reader, - buf, - pos: 0, + buf: ReadBuf::new(buf), } } @@ -31,8 +30,7 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadExact<'a, A: ?Sized> { reader: &'a mut A, - buf: &'a mut [u8], - pos: usize, + buf: ReadBuf<'a>, } } @@ -49,17 +47,15 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { // if our buffer is empty, then we need to read some data to continue. 
- if self.pos < self.buf.len() { + let rem = self.buf.remaining(); + if rem != 0 { let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf[me.pos..]))?; - me.pos += n; - if n == 0 { + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + if me.buf.remaining() == rem { return Err(eof()).into(); } - } - - if self.pos >= self.buf.len() { - return Poll::Ready(Ok(self.pos)); + } else { + return Poll::Ready(Ok(self.buf.capacity())); } } } diff --git a/tokio/src/io/util/read_int.rs b/tokio/src/io/util/read_int.rs index 9d37dc7a400..c3dbbd56943 100644 --- a/tokio/src/io/util/read_int.rs +++ b/tokio/src/io/util/read_int.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use bytes::Buf; use pin_project_lite::pin_project; @@ -48,17 +48,19 @@ macro_rules! reader { } while *me.read < $bytes as u8 { - *me.read += match me - .src - .as_mut() - .poll_read(cx, &mut me.buf[*me.read as usize..]) - { + let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]); + + *me.read += match me.src.as_mut().poll_read(cx, &mut buf) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => { - return Poll::Ready(Err(UnexpectedEof.into())); + Poll::Ready(Ok(())) => { + let n = buf.filled().len(); + if n == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + n as u8 } - Poll::Ready(Ok(n)) => n as u8, }; } @@ -97,12 +99,17 @@ macro_rules! reader8 { let me = self.project(); let mut buf = [0; 1]; - match me.reader.poll_read(cx, &mut buf[..]) { + let mut buf = ReadBuf::new(&mut buf); + match me.reader.poll_read(cx, &mut buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => Poll::Ready(Err(UnexpectedEof.into())), - Poll::Ready(Ok(1)) => Poll::Ready(Ok(buf[0] as $ty)), - Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Ok(())) => { + if buf.filled().len() == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + Poll::Ready(Ok(buf.filled()[0] as $ty)) + } } } } diff --git a/tokio/src/io/util/read_to_end.rs b/tokio/src/io/util/read_to_end.rs index 29b8b811f72..609af28e9fb 100644 --- a/tokio/src/io/util/read_to_end.rs +++ b/tokio/src/io/util/read_to_end.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::future::Future; use std::io; @@ -21,7 +21,6 @@ pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec) -> where R: AsyncRead + Unpin + ?Sized, { - prepare_buffer(buffer, reader); ReadToEnd { reader, buf: buffer, @@ -29,12 +28,7 @@ where } } -/// # Safety -/// -/// Before first calling this method, the unused capacity must have been -/// prepared for use with the provided AsyncRead. This can be done using the -/// `prepare_buffer` function later in this file. -pub(super) unsafe fn read_to_end_internal( +pub(super) fn read_to_end_internal( buf: &mut Vec, mut reader: Pin<&mut R>, num_read: &mut usize, @@ -56,13 +50,7 @@ pub(super) unsafe fn read_to_end_internal( /// Tries to read from the provided AsyncRead. /// /// The length of the buffer is increased by the number of bytes read. -/// -/// # Safety -/// -/// The caller ensures that the buffer has been prepared for use with the -/// AsyncRead before calling this function. This can be done using the -/// `prepare_buffer` function later in this file. 
-unsafe fn poll_read_to_end( +fn poll_read_to_end( buf: &mut Vec, read: Pin<&mut R>, cx: &mut Context<'_>, @@ -73,70 +61,32 @@ unsafe fn poll_read_to_end( // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every // time is 4,500 times (!) slower than this if the reader has a very small // amount of data to return. - reserve(buf, &*read, 32); - - let unused_capacity: &mut [MaybeUninit] = get_unused_capacity(buf); - - // safety: The buffer has been prepared for use with the AsyncRead before - // calling this function. - let slice: &mut [u8] = &mut *(unused_capacity as *mut [MaybeUninit] as *mut [u8]); - - let res = ready!(read.poll_read(cx, slice)); - if let Ok(num) = res { - // safety: There are two situations: - // - // 1. The AsyncRead has not overriden `prepare_uninitialized_buffer`. - // - // In this situation, the default implementation of that method will have - // zeroed the unused capacity. This means that setting the length will - // never expose uninitialized memory in the vector. - // - // Note that the assert! below ensures that we don't set the length to - // something larger than the capacity, which malicious implementors might - // try to have us do. - // - // 2. The AsyncRead has overriden `prepare_uninitialized_buffer`. - // - // In this case, the safety of the `set_len` call below relies on this - // guarantee from the documentation on `prepare_uninitialized_buffer`: - // - // > This function isn't actually unsafe to call but unsafe to implement. - // > The implementer must ensure that either the whole buf has been zeroed - // > or poll_read() overwrites the buffer without reading it and returns - // > correct value. - // - // Note that `prepare_uninitialized_buffer` is unsafe to implement, so this - // is a guarantee we can rely on in unsafe code. - // - // The assert!() is technically only necessary in the first case. - let new_len = buf.len() + num; - assert!(new_len <= buf.capacity()); + reserve(buf, 32); - buf.set_len(new_len); - } - Poll::Ready(res) -} + let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf)); -/// This function prepares the unused capacity for use with the provided AsyncRead. -pub(super) fn prepare_buffer(buf: &mut Vec, read: &R) { - let buffer = get_unused_capacity(buf); + ready!(read.poll_read(cx, &mut unused_capacity))?; - // safety: This function is only unsafe to implement. + let n = unused_capacity.filled().len(); + let new_len = buf.len() + n; + + // This should no longer even be possible in safe Rust. An implementor + // would need to have unsafely *replaced* the buffer inside `ReadBuf`, + // which... yolo? + assert!(new_len <= buf.capacity()); unsafe { - read.prepare_uninitialized_buffer(buffer); + buf.set_len(new_len); } + Poll::Ready(Ok(n)) } /// Allocates more memory and ensures that the unused capacity is prepared for use /// with the `AsyncRead`. -fn reserve(buf: &mut Vec, read: &R, bytes: usize) { +fn reserve(buf: &mut Vec, bytes: usize) { if buf.capacity() - buf.len() >= bytes { return; } buf.reserve(bytes); - // The call above has reallocated the buffer, so we must reinitialize the entire - // unused capacity, even if we already initialized some of it before the resize. - prepare_buffer(buf, read); } /// Returns the unused capacity of the provided vector. 
@@ -153,8 +103,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, buf, read } = &mut *self; - // safety: The constructor of ReadToEnd calls `prepare_buffer` - unsafe { read_to_end_internal(buf, Pin::new(*reader), read, cx) } + read_to_end_internal(buf, Pin::new(*reader), read, cx) } } diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index 4ef50be308c..cf00e50d918 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -1,5 +1,5 @@ use crate::io::util::read_line::finish_string_read; -use crate::io::util::read_to_end::{prepare_buffer, read_to_end_internal}; +use crate::io::util::read_to_end::read_to_end_internal; use crate::io::AsyncRead; use std::future::Future; @@ -31,8 +31,7 @@ pub(crate) fn read_to_string<'a, R>( where R: AsyncRead + ?Sized + Unpin, { - let mut buf = mem::replace(string, String::new()).into_bytes(); - prepare_buffer(&mut buf, reader); + let buf = mem::replace(string, String::new()).into_bytes(); ReadToString { reader, buf, diff --git a/tokio/src/io/util/repeat.rs b/tokio/src/io/util/repeat.rs index eeef7cc187b..b942691d331 100644 --- a/tokio/src/io/util/repeat.rs +++ b/tokio/src/io/util/repeat.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -47,19 +47,17 @@ cfg_io_util! { } impl AsyncRead for Repeat { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - for byte in &mut *buf { - *byte = self.byte; + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // TODO: could be faster, but should we unsafe it? + while buf.remaining() != 0 { + buf.append(&[self.byte]); } - Poll::Ready(Ok(buf.len())) + Poll::Ready(Ok(())) } } diff --git a/tokio/src/io/util/stream_reader.rs b/tokio/src/io/util/stream_reader.rs index b98f8bdfc28..2471197a46e 100644 --- a/tokio/src/io/util/stream_reader.rs +++ b/tokio/src/io/util/stream_reader.rs @@ -1,9 +1,8 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use crate::stream::Stream; use bytes::{Buf, BufMut}; use pin_project_lite::pin_project; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -103,10 +102,10 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); } let inner_buf = match self.as_mut().poll_fill_buf(cx) { @@ -114,11 +113,11 @@ where Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, }; - let len = std::cmp::min(inner_buf.len(), buf.len()); - (&mut buf[..len]).copy_from_slice(&inner_buf[..len]); + let len = std::cmp::min(inner_buf.len(), buf.remaining()); + buf.append(&inner_buf[..len]); self.consume(len); - Poll::Ready(Ok(len)) + Poll::Ready(Ok(())) } fn poll_read_buf( mut self: Pin<&mut Self>, @@ -143,9 +142,6 @@ where self.consume(len); Poll::Ready(Ok(len)) } - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit]) -> bool { - false - } } impl AsyncBufRead for StreamReader diff --git a/tokio/src/io/util/take.rs b/tokio/src/io/util/take.rs index 5d6bd90aa31..2abc7693172 100644 --- a/tokio/src/io/util/take.rs +++ b/tokio/src/io/util/take.rs @@ -1,7 +1,6 @@ -use 
crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, io}; @@ -76,24 +75,30 @@ impl Take { } impl AsyncRead for Take { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { if self.limit_ == 0 { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } let me = self.project(); - let max = std::cmp::min(buf.len() as u64, *me.limit_) as usize; - let n = ready!(me.inner.poll_read(cx, &mut buf[..max]))?; + let max = std::cmp::min(buf.remaining() as u64, *me.limit_) as usize; + // Make a ReadBuf of the unfulled section up to max + // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. + let mut b = unsafe { ReadBuf::uninit(&mut buf.unfilled_mut()[..max]) }; + ready!(me.inner.poll_read(cx, &mut b))?; + let n = b.filled().len(); + + // We need to update the original ReadBuf + unsafe { + buf.assume_init(n); + } + buf.add_filled(n); *me.limit_ -= n as u64; - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } diff --git a/tokio/src/net/tcp/split.rs b/tokio/src/net/tcp/split.rs index 0c1e359f72d..9d99d7bdfbf 100644 --- a/tokio/src/net/tcp/split.rs +++ b/tokio/src/net/tcp/split.rs @@ -9,12 +9,11 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::TcpStream; use bytes::Buf; use std::io; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; @@ -131,15 +130,11 @@ impl ReadHalf<'_> { } impl AsyncRead for ReadHalf<'_> { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.0.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/tcp/split_owned.rs b/tokio/src/net/tcp/split_owned.rs index 6c2b9e6977e..87be6efd8a1 100644 --- a/tokio/src/net/tcp/split_owned.rs +++ b/tokio/src/net/tcp/split_owned.rs @@ -9,12 +9,11 @@ //! level. 
use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::TcpStream; use bytes::Buf; use std::error::Error; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::sync::Arc; @@ -186,15 +185,11 @@ impl OwnedReadHalf { } impl AsyncRead for OwnedReadHalf { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.inner.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 02b5262723e..e624fb9d954 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -1,5 +1,5 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented}; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::ToSocketAddrs; @@ -9,7 +9,6 @@ use iovec::IoVec; use std::convert::TryFrom; use std::fmt; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; use std::net::{self, Shutdown, SocketAddr}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -702,16 +701,28 @@ impl TcpStream { pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - match self.io.get_ref().read(buf) { + // Safety: `TcpStream::read` will not peak at the maybe uinitialized bytes. + let b = + unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + match self.io.get_ref().read(b) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.clear_read_ready(cx, mio::Ready::readable())?; Poll::Pending } - x => Poll::Ready(x), + Ok(n) => { + // Safety: We trust `TcpStream::read` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.add_filled(n); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), } } @@ -864,15 +875,11 @@ impl TryFrom for TcpStream { // ===== impl Read / Write ===== impl AsyncRead for TcpStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/unix/split.rs b/tokio/src/net/unix/split.rs index 4fd85774e9a..460bbc1954b 100644 --- a/tokio/src/net/unix/split.rs +++ b/tokio/src/net/unix/split.rs @@ -8,11 +8,10 @@ //! split has no associated overhead and enforces all invariants at the type //! level. 
-use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::UnixStream; use std::io; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; @@ -51,15 +50,11 @@ pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) { } impl AsyncRead for ReadHalf<'_> { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.0.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/unix/split_owned.rs b/tokio/src/net/unix/split_owned.rs index eb35304bfa2..ab233072b35 100644 --- a/tokio/src/net/unix/split_owned.rs +++ b/tokio/src/net/unix/split_owned.rs @@ -8,11 +8,10 @@ //! split has no associated overhead and enforces all invariants at the type //! level. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::UnixStream; use std::error::Error; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::sync::Arc; @@ -109,15 +108,11 @@ impl OwnedReadHalf { } impl AsyncRead for OwnedReadHalf { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.inner.poll_read_priv(cx, buf) } } diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 5fe242d0887..559fe02a625 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -1,5 +1,5 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented}; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::net::unix::split::{split, ReadHalf, WriteHalf}; use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::unix::ucred::{self, UCred}; @@ -7,7 +7,6 @@ use crate::net::unix::ucred::{self, UCred}; use std::convert::TryFrom; use std::fmt; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{self, SocketAddr}; @@ -167,15 +166,11 @@ impl TryFrom for UnixStream { } impl AsyncRead for UnixStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { self.poll_read_priv(cx, buf) } } @@ -214,16 +209,28 @@ impl UnixStream { pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - match self.io.get_ref().read(buf) { + // Safety: `UnixStream::read` will not peak at the maybe uinitialized bytes. + let b = + unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + match self.io.get_ref().read(b) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.clear_read_ready(cx, mio::Ready::readable())?; Poll::Pending } - x => Poll::Ready(x), + Ok(n) => { + // Safety: We trust `UnixStream::read` to have filled up `n` bytes + // in the buffer. 
+ unsafe { + buf.assume_init(n); + } + buf.add_filled(n); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), } } diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index 4a070023b06..a3b7c384101 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -120,7 +120,7 @@ mod imp; mod kill; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::process::kill::Kill; use std::ffi::OsStr; @@ -909,31 +909,21 @@ impl AsyncWrite for ChildStdin { } impl AsyncRead for ChildStdout { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L314 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.inner).poll_read(cx, buf) } } impl AsyncRead for ChildStderr { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L375 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.inner).poll_read(cx, buf) } } diff --git a/tokio/src/signal/unix.rs b/tokio/src/signal/unix.rs index b46b15c99a6..bc48bdfaa64 100644 --- a/tokio/src/signal/unix.rs +++ b/tokio/src/signal/unix.rs @@ -5,7 +5,7 @@ #![cfg(unix)] -use crate::io::{AsyncRead, PollEvented}; +use crate::io::{AsyncRead, PollEvented, ReadBuf}; use crate::signal::registry::{globals, EventId, EventInfo, Globals, Init, Storage}; use crate::sync::mpsc::{channel, Receiver}; @@ -300,10 +300,16 @@ impl Driver { /// [#38](https://github.com/alexcrichton/tokio-signal/issues/38) for more /// info. 
fn drain(&mut self, cx: &mut Context<'_>) { + let mut buf = [0; 128]; + let mut buf = ReadBuf::new(&mut buf); loop { - match Pin::new(&mut self.wakeup).poll_read(cx, &mut [0; 128]) { - Poll::Ready(Ok(0)) => panic!("EOF on self-pipe"), - Poll::Ready(Ok(_)) => {} + match Pin::new(&mut self.wakeup).poll_read(cx, &mut buf) { + Poll::Ready(Ok(())) => { + if buf.filled().is_empty() { + panic!("EOF on self-pipe") + } + buf.clear(); + } Poll::Ready(Err(e)) => panic!("Bad read on self-pipe: {}", e), Poll::Pending => break, } diff --git a/tokio/tests/io_async_read.rs b/tokio/tests/io_async_read.rs index 20440bbde35..d1aae9a1a7f 100644 --- a/tokio/tests/io_async_read.rs +++ b/tokio/tests/io_async_read.rs @@ -1,14 +1,12 @@ -#![allow(clippy::transmute_ptr_to_ptr)] #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncRead; +use tokio::io::{AsyncRead, ReadBuf}; use tokio_test::task; use tokio_test::{assert_ready_err, assert_ready_ok}; -use bytes::{BufMut, BytesMut}; +use bytes::BytesMut; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -26,10 +24,10 @@ fn read_buf_success() { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - buf[0..11].copy_from_slice(b"hello world"); - Poll::Ready(Ok(11)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + buf.append(b"hello world"); + Poll::Ready(Ok(())) } } @@ -51,8 +49,8 @@ fn read_buf_error() { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { + _buf: &mut ReadBuf<'_>, + ) -> Poll> { let err = io::ErrorKind::Other.into(); Poll::Ready(Err(err)) } @@ -74,8 +72,8 @@ fn read_buf_no_capacity() { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { + _buf: &mut ReadBuf<'_>, + ) -> Poll> { unimplemented!(); } } @@ -88,59 +86,26 @@ fn read_buf_no_capacity() { }); } -#[test] -fn read_buf_no_uninitialized() { - struct Rd; - - impl AsyncRead for Rd { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - for b in buf { - assert_eq!(0, *b); - } - - Poll::Ready(Ok(0)) - } - } - - let mut buf = BytesMut::with_capacity(64); - - task::spawn(Rd).enter(|cx, rd| { - let n = assert_ready_ok!(rd.poll_read_buf(cx, &mut buf)); - assert_eq!(0, n); - }); -} - #[test] fn read_buf_uninitialized_ok() { struct Rd; impl AsyncRead for Rd { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - assert_eq!(buf[0..11], b"hello world"[..]); - Poll::Ready(Ok(0)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + assert_eq!(buf.remaining(), 64); + assert_eq!(buf.filled().len(), 0); + assert_eq!(buf.initialized().len(), 0); + Poll::Ready(Ok(())) } } // Can't create BytesMut w/ zero capacity, so fill it up let mut buf = BytesMut::with_capacity(64); - unsafe { - let b: &mut [u8] = std::mem::transmute(buf.bytes_mut()); - b[0..11].copy_from_slice(b"hello world"); - } - task::spawn(Rd).enter(|cx, rd| { let n = assert_ready_ok!(rd.poll_read_buf(cx, &mut buf)); assert_eq!(0, n); diff --git a/tokio/tests/io_copy.rs b/tokio/tests/io_copy.rs index c1c6df4eb34..aed6c789d23 100644 --- a/tokio/tests/io_copy.rs +++ b/tokio/tests/io_copy.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{self, AsyncRead}; +use tokio::io::{self, AsyncRead, ReadBuf}; use tokio_test::assert_ok; use std::pin::Pin; @@ -15,14 +15,14 @@ async fn copy() { fn 
poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { if self.0 { - buf[0..11].copy_from_slice(b"hello world"); + buf.append(b"hello world"); self.0 = false; - Poll::Ready(Ok(11)) + Poll::Ready(Ok(())) } else { - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) } } } diff --git a/tokio/tests/io_read.rs b/tokio/tests/io_read.rs index 4791c9a6618..0a717cf519e 100644 --- a/tokio/tests/io_read.rs +++ b/tokio/tests/io_read.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use tokio_test::assert_ok; use std::io; @@ -19,13 +19,13 @@ async fn read() { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { assert_eq!(0, self.poll_cnt); self.poll_cnt += 1; - buf[0..11].copy_from_slice(b"hello world"); - Poll::Ready(Ok(11)) + buf.append(b"hello world"); + Poll::Ready(Ok(())) } } @@ -36,25 +36,3 @@ async fn read() { assert_eq!(n, 11); assert_eq!(buf[..], b"hello world"[..]); } - -struct BadAsyncRead; - -impl AsyncRead for BadAsyncRead { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - for b in &mut *buf { - *b = b'a'; - } - Poll::Ready(Ok(buf.len() * 2)) - } -} - -#[tokio::test] -#[should_panic] -async fn read_buf_bad_async_read() { - let mut buf = Vec::with_capacity(10); - BadAsyncRead.read_buf(&mut buf).await.unwrap(); -} diff --git a/tokio/tests/io_split.rs b/tokio/tests/io_split.rs index e54bf248521..7b401424151 100644 --- a/tokio/tests/io_split.rs +++ b/tokio/tests/io_split.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{split, AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}; use std::io; use std::pin::Pin; @@ -13,9 +13,10 @@ impl AsyncRead for RW { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { - Poll::Ready(Ok(1)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + buf.append(&[b'z']); + Poll::Ready(Ok(())) } }
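For downstream implementors, the migration shown throughout this patch boils down to: accept `&mut ReadBuf<'_>` instead of `&mut [u8]`, copy data in with `append` (bounded by `remaining()`), and return `Poll::Ready(Ok(()))`, with end-of-stream signalled by not growing the filled region rather than by returning `Ok(0)`. A minimal sketch of that pattern follows; the `InMemoryReader` type and its fields are hypothetical and only for illustration, while the `ReadBuf` methods used (`remaining`, `append`) are the ones introduced in this diff.

use std::cmp;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

// Hypothetical in-memory reader used to illustrate the new signature.
struct InMemoryReader {
    data: Vec<u8>,
    pos: usize,
}

impl AsyncRead for InMemoryReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Copy at most `buf.remaining()` bytes; `append` advances the filled
        // (and, if necessary, initialized) cursors of the `ReadBuf` for us.
        let n = cmp::min(self.data.len() - self.pos, buf.remaining());
        let pos = self.pos;
        buf.append(&self.data[pos..pos + n]);
        self.pos += n;

        // EOF is now conveyed by returning Ok(()) without adding any bytes to
        // the filled region, instead of the old `Ok(0)` return value.
        Poll::Ready(Ok(()))
    }
}

Callers observe progress by comparing `buf.filled().len()` before and after the call, which is exactly what the updated `tokio-test` mock and `chain`/`read_exact` utilities in this patch do.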