diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index e75ea03424c..e06e7e25e47 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -48,6 +48,7 @@ cfg_io_util! { mod read_line; mod read_to_end; + mod vec_with_initialized; cfg_process! { pub(crate) use read_to_end::read_to_end; } @@ -82,6 +83,7 @@ cfg_io_util! { cfg_not_io_util! { cfg_process! { + mod vec_with_initialized; mod read_to_end; // Used by process pub(crate) use read_to_end::read_to_end; diff --git a/tokio/src/io/util/read_to_end.rs b/tokio/src/io/util/read_to_end.rs index 1aee6810ee4..dff7d66d0c6 100644 --- a/tokio/src/io/util/read_to_end.rs +++ b/tokio/src/io/util/read_to_end.rs @@ -1,10 +1,11 @@ -use crate::io::{AsyncRead, ReadBuf}; +use crate::io::util::vec_with_initialized::{into_read_buf_parts, VecWithInitialized}; +use crate::io::AsyncRead; use pin_project_lite::pin_project; use std::future::Future; use std::io; use std::marker::PhantomPinned; -use std::mem::{self, MaybeUninit}; +use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; @@ -13,7 +14,7 @@ pin_project! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadToEnd<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut Vec, + buf: VecWithInitialized<&'a mut Vec>, // The number of bytes appended to buf. This can be less than buf.len() if // the buffer was not empty when the operation was started. read: usize, @@ -27,22 +28,22 @@ pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec) -> where R: AsyncRead + Unpin + ?Sized, { + // SAFETY: The generic type on VecWithInitialized is &mut Vec. ReadToEnd { reader, - buf: buffer, + buf: unsafe { VecWithInitialized::new(buffer) }, read: 0, _pin: PhantomPinned, } } -pub(super) fn read_to_end_internal( - buf: &mut Vec, +pub(super) fn read_to_end_internal>, R: AsyncRead + ?Sized>( + buf: &mut VecWithInitialized, mut reader: Pin<&mut R>, num_read: &mut usize, cx: &mut Context<'_>, ) -> Poll> { loop { - // safety: The caller promised to prepare the buffer. let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx)); match ret { Err(err) => return Poll::Ready(Err(err)), @@ -57,8 +58,8 @@ pub(super) fn read_to_end_internal( /// Tries to read from the provided AsyncRead. /// /// The length of the buffer is increased by the number of bytes read. -fn poll_read_to_end( - buf: &mut Vec, +fn poll_read_to_end>, R: AsyncRead + ?Sized>( + buf: &mut VecWithInitialized, read: Pin<&mut R>, cx: &mut Context<'_>, ) -> Poll> { @@ -68,37 +69,34 @@ fn poll_read_to_end( // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every // time is 4,500 times (!) slower than this if the reader has a very small // amount of data to return. - reserve(buf, 32); + buf.reserve(32); - let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf)); + // Get a ReadBuf into the vector. + let mut read_buf = buf.get_read_buf(); - let ptr = unused_capacity.filled().as_ptr(); - ready!(read.poll_read(cx, &mut unused_capacity))?; - assert_eq!(ptr, unused_capacity.filled().as_ptr()); + let filled_before = read_buf.filled().len(); + let poll_result = read.poll_read(cx, &mut read_buf); + let filled_after = read_buf.filled().len(); + let n = filled_after - filled_before; - let n = unused_capacity.filled().len(); - let new_len = buf.len() + n; + // Update the length of the vector using the result of poll_read. + let read_buf_parts = into_read_buf_parts(read_buf); + buf.apply_read_buf(read_buf_parts); - assert!(new_len <= buf.capacity()); - unsafe { - buf.set_len(new_len); - } - Poll::Ready(Ok(n)) -} - -/// Allocates more memory and ensures that the unused capacity is prepared for use -/// with the `AsyncRead`. -fn reserve(buf: &mut Vec, bytes: usize) { - if buf.capacity() - buf.len() >= bytes { - return; + match poll_result { + Poll::Pending => { + // In this case, nothing should have been read. However we still + // update the vector in case the poll_read call initialized parts of + // the vector's unused capacity. + debug_assert_eq!(filled_before, filled_after); + Poll::Pending + } + Poll::Ready(Err(err)) => { + debug_assert_eq!(filled_before, filled_after); + Poll::Ready(Err(err)) + } + Poll::Ready(Ok(())) => Poll::Ready(Ok(n)), } - buf.reserve(bytes); -} - -/// Returns the unused capacity of the provided vector. -fn get_unused_capacity(buf: &mut Vec) -> &mut [MaybeUninit] { - let uninit = bytes::BufMut::chunk_mut(buf); - unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit]) } } impl Future for ReadToEnd<'_, A> diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index e463203c0ae..215ead370d8 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -1,5 +1,6 @@ use crate::io::util::read_line::finish_string_read; use crate::io::util::read_to_end::read_to_end_internal; +use crate::io::util::vec_with_initialized::VecWithInitialized; use crate::io::AsyncRead; use pin_project_lite::pin_project; @@ -19,7 +20,7 @@ pin_project! { // while reading to postpone utf-8 handling until after reading. output: &'a mut String, // The actual allocation of the string is moved into this vector instead. - buf: Vec, + buf: VecWithInitialized>, // The number of bytes appended to buf. This can be less than buf.len() if // the buffer was not empty when the operation was started. read: usize, @@ -37,29 +38,25 @@ where R: AsyncRead + ?Sized + Unpin, { let buf = mem::replace(string, String::new()).into_bytes(); + // SAFETY: The generic type of the VecWithInitialized is Vec. ReadToString { reader, - buf, + buf: unsafe { VecWithInitialized::new(buf) }, output: string, read: 0, _pin: PhantomPinned, } } -/// # Safety -/// -/// Before first calling this method, the unused capacity must have been -/// prepared for use with the provided AsyncRead. This can be done using the -/// `prepare_buffer` function in `read_to_end.rs`. -unsafe fn read_to_string_internal( +fn read_to_string_internal( reader: Pin<&mut R>, output: &mut String, - buf: &mut Vec, + buf: &mut VecWithInitialized>, read: &mut usize, cx: &mut Context<'_>, ) -> Poll> { let io_res = ready!(read_to_end_internal(buf, reader, read, cx)); - let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + let utf8_res = String::from_utf8(buf.take()); // At this point both buf and output are empty. The allocation is in utf8_res. @@ -77,7 +74,6 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let me = self.project(); - // safety: The constructor of ReadToString called `prepare_buffer`. - unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) } + read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) } } diff --git a/tokio/src/io/util/vec_with_initialized.rs b/tokio/src/io/util/vec_with_initialized.rs new file mode 100644 index 00000000000..ca24719dbf6 --- /dev/null +++ b/tokio/src/io/util/vec_with_initialized.rs @@ -0,0 +1,120 @@ +use crate::io::ReadBuf; +use std::mem::MaybeUninit; + +/// This struct wraps a `Vec` or `&mut Vec`, combining it with a +/// `num_initialized`, which keeps track of the number of initialized bytes +/// in the unused capacity. +/// +/// The purpose of this struct is to remember how many bytes were initialized +/// through a `ReadBuf` from call to call. +/// +/// This struct has the safety invariant that the first `num_initialized` of the +/// vector's allocation must be initialized at any time. +#[derive(Debug)] +pub(crate) struct VecWithInitialized { + vec: V, + // The number of initialized bytes in the vector. + // Always between `vec.len()` and `vec.capacity()`. + num_initialized: usize, +} + +impl VecWithInitialized> { + #[cfg(feature = "io-util")] + pub(crate) fn take(&mut self) -> Vec { + self.num_initialized = 0; + std::mem::take(&mut self.vec) + } +} + +impl VecWithInitialized +where + V: AsMut>, +{ + /// Safety: The generic parameter `V` must be either `Vec` or `&mut Vec`. + pub(crate) unsafe fn new(mut vec: V) -> Self { + // SAFETY: The safety invariants of vector guarantee that the bytes up + // to its length are initialized. + Self { + num_initialized: vec.as_mut().len(), + vec, + } + } + + pub(crate) fn reserve(&mut self, num_bytes: usize) { + let vec = self.vec.as_mut(); + if vec.capacity() - vec.len() >= num_bytes { + return; + } + // SAFETY: Setting num_initialized to `vec.len()` is correct as + // `reserve` does not change the length of the vector. + self.num_initialized = vec.len(); + vec.reserve(num_bytes); + } + + #[cfg(feature = "io-util")] + pub(crate) fn is_empty(&mut self) -> bool { + self.vec.as_mut().is_empty() + } + + pub(crate) fn get_read_buf<'a>(&'a mut self) -> ReadBuf<'a> { + let num_initialized = self.num_initialized; + + // SAFETY: Creating the slice is safe because of the safety invariants + // on Vec. The safety invariants of `ReadBuf` will further guarantee + // that no bytes in the slice are de-initialized. + let vec = self.vec.as_mut(); + let len = vec.len(); + let cap = vec.capacity(); + let ptr = vec.as_mut_ptr().cast::>(); + let slice = unsafe { std::slice::from_raw_parts_mut::<'a, MaybeUninit>(ptr, cap) }; + + // SAFETY: This is safe because the safety invariants of + // VecWithInitialized say that the first num_initialized bytes must be + // initialized. + let mut read_buf = ReadBuf::uninit(slice); + unsafe { + read_buf.assume_init(num_initialized); + } + read_buf.set_filled(len); + + read_buf + } + + pub(crate) fn apply_read_buf(&mut self, parts: ReadBufParts) { + let vec = self.vec.as_mut(); + assert_eq!(vec.as_ptr(), parts.ptr); + + // SAFETY: + // The ReadBufParts really does point inside `self.vec` due to the above + // check, and the safety invariants of `ReadBuf` guarantee that the + // first `parts.initialized` bytes of `self.vec` really have been + // initialized. Additionally, `ReadBuf` guarantees that `parts.len` is + // at most `parts.initialized`, so the first `parts.len` bytes are also + // initialized. + // + // Note that this relies on the fact that `V` is either `Vec` or + // `&mut Vec`, so the vector returned by `self.vec.as_mut()` cannot + // change from call to call. + unsafe { + self.num_initialized = parts.initialized; + vec.set_len(parts.len); + } + } +} + +pub(crate) struct ReadBufParts { + // Pointer is only used to check that the ReadBuf actually came from the + // right VecWithInitialized. + ptr: *const u8, + len: usize, + initialized: usize, +} + +// This is needed to release the borrow on `VecWithInitialized`. +pub(crate) fn into_read_buf_parts(rb: ReadBuf<'_>) -> ReadBufParts { + ReadBufParts { + ptr: rb.filled().as_ptr(), + len: rb.filled().len(), + initialized: rb.initialized().len(), + } +} diff --git a/tokio/tests/io_read_to_end.rs b/tokio/tests/io_read_to_end.rs index ee636ba5963..171e6d6480e 100644 --- a/tokio/tests/io_read_to_end.rs +++ b/tokio/tests/io_read_to_end.rs @@ -1,7 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncReadExt; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use tokio_test::assert_ok; #[tokio::test] @@ -13,3 +15,64 @@ async fn read_to_end() { assert_eq!(n, 11); assert_eq!(buf[..], b"hello world"[..]); } + +#[derive(Copy, Clone, Debug)] +enum State { + Initializing, + JustFilling, + Done, +} + +struct UninitTest { + num_init: usize, + state: State, +} + +impl AsyncRead for UninitTest { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = Pin::into_inner(self); + let real_num_init = buf.initialized().len() - buf.filled().len(); + assert_eq!(real_num_init, me.num_init, "{:?}", me.state); + + match me.state { + State::Initializing => { + buf.initialize_unfilled_to(me.num_init + 2); + buf.advance(1); + me.num_init += 1; + + if me.num_init == 24 { + me.state = State::JustFilling; + } + } + State::JustFilling => { + buf.advance(1); + me.num_init -= 1; + + if me.num_init == 15 { + // The buffer is resized on next call. + me.num_init = 0; + me.state = State::Done; + } + } + State::Done => { /* .. do nothing .. */ } + } + + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn read_to_end_uninit() { + let mut buf = Vec::with_capacity(64); + let mut test = UninitTest { + num_init: 0, + state: State::Initializing, + }; + + test.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf.len(), 33); +}