diff --git a/tokio-test/src/io.rs b/tokio-test/src/io.rs index 26ef57e47a2..a0f36bb48a3 100644 --- a/tokio-test/src/io.rs +++ b/tokio-test/src/io.rs @@ -18,7 +18,7 @@ //! [`AsyncRead`]: tokio::io::AsyncRead //! [`AsyncWrite`]: tokio::io::AsyncWrite -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::mpsc; use tokio::time::{self, Delay, Duration, Instant}; @@ -204,20 +204,20 @@ impl Inner { self.rx.poll_recv(cx) } - fn read(&mut self, dst: &mut [u8]) -> io::Result { + fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> { match self.action() { Some(&mut Action::Read(ref mut data)) => { // Figure out how much to copy - let n = cmp::min(dst.len(), data.len()); + let n = cmp::min(dst.remaining(), data.len()); // Copy the data into the `dst` slice - (&mut dst[..n]).copy_from_slice(&data[..n]); + dst.append(&data[..n]); // Drain the data from the source data.drain(..n); // Return the number of bytes read - Ok(n) + Ok(()) } Some(&mut Action::ReadError(ref mut err)) => { // As the @@ -229,7 +229,7 @@ impl Inner { // Either waiting or expecting a write Err(io::ErrorKind::WouldBlock.into()) } - None => Ok(0), + None => Ok(()), } } @@ -348,8 +348,8 @@ impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { loop { if let Some(ref mut sleep) = self.inner.sleep { ready!(Pin::new(sleep).poll(cx)); @@ -358,6 +358,9 @@ impl AsyncRead for Mock { // If a sleep is set, it has already fired self.inner.sleep = None; + // Capture 'filled' to monitor if it changed + let filled = buf.filled().len(); + match self.inner.read(buf) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { if let Some(rem) = self.inner.remaining_wait() { @@ -368,19 +371,22 @@ impl AsyncRead for Mock { return Poll::Pending; } } - Ok(0) => { - // TODO: Extract - match ready!(self.inner.poll_action(cx)) { - Some(action) => { - self.inner.actions.push_back(action); - continue; - } - None => { - return Poll::Ready(Ok(0)); + Ok(()) => { + if buf.filled().len() == filled { + match ready!(self.inner.poll_action(cx)) { + Some(action) => { + self.inner.actions.push_back(action); + continue; + } + None => { + return Poll::Ready(Ok(())); + } } + } else { + return Poll::Ready(Ok(())) } } - ret => return Poll::Ready(ret), + Err(e) => return Poll::Ready(Err(e)), } } } diff --git a/tokio/tests/io_async_read.rs b/tokio/tests/io_async_read.rs index 20440bbde35..800ca00b907 100644 --- a/tokio/tests/io_async_read.rs +++ b/tokio/tests/io_async_read.rs @@ -2,13 +2,12 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncRead; +use tokio::io::{AsyncRead, ReadBuf}; use tokio_test::task; use tokio_test::{assert_ready_err, assert_ready_ok}; -use bytes::{BufMut, BytesMut}; +use bytes::{BytesMut}; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -26,10 +25,10 @@ fn read_buf_success() { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - buf[0..11].copy_from_slice(b"hello world"); - Poll::Ready(Ok(11)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + buf.append(b"hello world"); + Poll::Ready(Ok(())) } } @@ -51,8 +50,8 @@ fn read_buf_error() { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { + _buf: &mut ReadBuf<'_>, + ) -> Poll> { let err = io::ErrorKind::Other.into(); Poll::Ready(Err(err)) } @@ -74,8 +73,8 @@ fn read_buf_no_capacity() { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { + _buf: &mut ReadBuf<'_>, + ) -> Poll> { unimplemented!(); } } @@ -88,59 +87,26 @@ fn read_buf_no_capacity() { }); } -#[test] -fn read_buf_no_uninitialized() { - struct Rd; - - impl AsyncRead for Rd { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - for b in buf { - assert_eq!(0, *b); - } - - Poll::Ready(Ok(0)) - } - } - - let mut buf = BytesMut::with_capacity(64); - - task::spawn(Rd).enter(|cx, rd| { - let n = assert_ready_ok!(rd.poll_read_buf(cx, &mut buf)); - assert_eq!(0, n); - }); -} - #[test] fn read_buf_uninitialized_ok() { struct Rd; impl AsyncRead for Rd { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - assert_eq!(buf[0..11], b"hello world"[..]); - Poll::Ready(Ok(0)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + assert_eq!(buf.remaining(), 64); + assert_eq!(buf.filled().len(), 0); + assert_eq!(buf.initialized().len(), 0); + Poll::Ready(Ok(())) } } // Can't create BytesMut w/ zero capacity, so fill it up let mut buf = BytesMut::with_capacity(64); - unsafe { - let b: &mut [u8] = std::mem::transmute(buf.bytes_mut()); - b[0..11].copy_from_slice(b"hello world"); - } - task::spawn(Rd).enter(|cx, rd| { let n = assert_ready_ok!(rd.poll_read_buf(cx, &mut buf)); assert_eq!(0, n); diff --git a/tokio/tests/io_copy.rs b/tokio/tests/io_copy.rs index c1c6df4eb34..aed6c789d23 100644 --- a/tokio/tests/io_copy.rs +++ b/tokio/tests/io_copy.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{self, AsyncRead}; +use tokio::io::{self, AsyncRead, ReadBuf}; use tokio_test::assert_ok; use std::pin::Pin; @@ -15,14 +15,14 @@ async fn copy() { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { if self.0 { - buf[0..11].copy_from_slice(b"hello world"); + buf.append(b"hello world"); self.0 = false; - Poll::Ready(Ok(11)) + Poll::Ready(Ok(())) } else { - Poll::Ready(Ok(0)) + Poll::Ready(Ok(())) } } } diff --git a/tokio/tests/io_read.rs b/tokio/tests/io_read.rs index 4791c9a6618..0a717cf519e 100644 --- a/tokio/tests/io_read.rs +++ b/tokio/tests/io_read.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use tokio_test::assert_ok; use std::io; @@ -19,13 +19,13 @@ async fn read() { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { assert_eq!(0, self.poll_cnt); self.poll_cnt += 1; - buf[0..11].copy_from_slice(b"hello world"); - Poll::Ready(Ok(11)) + buf.append(b"hello world"); + Poll::Ready(Ok(())) } } @@ -36,25 +36,3 @@ async fn read() { assert_eq!(n, 11); assert_eq!(buf[..], b"hello world"[..]); } - -struct BadAsyncRead; - -impl AsyncRead for BadAsyncRead { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - for b in &mut *buf { - *b = b'a'; - } - Poll::Ready(Ok(buf.len() * 2)) - } -} - -#[tokio::test] -#[should_panic] -async fn read_buf_bad_async_read() { - let mut buf = Vec::with_capacity(10); - BadAsyncRead.read_buf(&mut buf).await.unwrap(); -} diff --git a/tokio/tests/io_split.rs b/tokio/tests/io_split.rs index e54bf248521..7b401424151 100644 --- a/tokio/tests/io_split.rs +++ b/tokio/tests/io_split.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{split, AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf}; use std::io; use std::pin::Pin; @@ -13,9 +13,10 @@ impl AsyncRead for RW { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { - Poll::Ready(Ok(1)) + buf: &mut ReadBuf<'_>, + ) -> Poll> { + buf.append(&[b'z']); + Poll::Ready(Ok(())) } }