From ea43500fcad318bd3a1eec71e2096a5fe814de6b Mon Sep 17 00:00:00 2001 From: Nemo157 Date: Sat, 5 Sep 2020 04:28:52 +0200 Subject: [PATCH] Add test utility that verifies an AsyncWrite is closed correctly (#2159) * Add test utility that verifies an AsyncWrite is closed correctly * Add track_closed for sinks too --- futures-test/Cargo.toml | 1 + futures-test/src/io/write/mod.rs | 40 +++++++++ futures-test/src/lib.rs | 4 + futures-test/src/sink/mod.rs | 56 ++++++++++++ futures-test/src/track_closed.rs | 146 +++++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+) create mode 100644 futures-test/src/sink/mod.rs create mode 100644 futures-test/src/track_closed.rs diff --git a/futures-test/Cargo.toml b/futures-test/Cargo.toml index 3abe28eb6c..ba53b73b98 100644 --- a/futures-test/Cargo.toml +++ b/futures-test/Cargo.toml @@ -17,6 +17,7 @@ futures-task = { version = "0.3.5", path = "../futures-task", default-features = futures-io = { version = "0.3.5", path = "../futures-io", default-features = false } futures-util = { version = "0.3.5", path = "../futures-util", default-features = false } futures-executor = { version = "0.3.5", path = "../futures-executor", default-features = false } +futures-sink = { version = "0.3.5", path = "../futures-sink", default-features = false } pin-utils = { version = "0.1.0", default-features = false } once_cell = { version = "1.3.1", default-features = false, features = ["std"], optional = true } pin-project = "0.4.20" diff --git a/futures-test/src/io/write/mod.rs b/futures-test/src/io/write/mod.rs index d228dd6a77..9d34ee00a8 100644 --- a/futures-test/src/io/write/mod.rs +++ b/futures-test/src/io/write/mod.rs @@ -4,6 +4,7 @@ use futures_io::AsyncWrite; pub use super::limited::Limited; pub use crate::interleave_pending::InterleavePending; +pub use crate::track_closed::TrackClosed; /// Additional combinators for testing async writers. pub trait AsyncWriteTestExt: AsyncWrite { @@ -80,6 +81,45 @@ pub trait AsyncWriteTestExt: AsyncWrite { { Limited::new(self, limit) } + + /// Track whether this stream has been closed and errors if it is used after closing. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::io::{AsyncWriteExt, Cursor}; + /// use futures_test::io::AsyncWriteTestExt; + /// + /// let mut writer = Cursor::new(vec![0u8; 4]).track_closed(); + /// + /// writer.write_all(&[1, 2]).await?; + /// assert!(!writer.is_closed()); + /// writer.close().await?; + /// assert!(writer.is_closed()); + /// + /// # Ok::<(), std::io::Error>(()) })?; + /// # Ok::<(), std::io::Error>(()) + /// ``` + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::io::{AsyncWriteExt, Cursor}; + /// use futures_test::io::AsyncWriteTestExt; + /// + /// let mut writer = Cursor::new(vec![0u8; 4]).track_closed(); + /// + /// writer.close().await?; + /// assert!(writer.write_all(&[1, 2]).await.is_err()); + /// # Ok::<(), std::io::Error>(()) })?; + /// # Ok::<(), std::io::Error>(()) + /// ``` + fn track_closed(self) -> TrackClosed + where + Self: Sized, + { + TrackClosed::new(self) + } } impl AsyncWriteTestExt for W where W: AsyncWrite {} diff --git a/futures-test/src/lib.rs b/futures-test/src/lib.rs index 411a8ad3c7..06d6c61105 100644 --- a/futures-test/src/lib.rs +++ b/futures-test/src/lib.rs @@ -39,7 +39,11 @@ pub mod future; #[cfg(feature = "std")] pub mod stream; +#[cfg(feature = "std")] +pub mod sink; + #[cfg(feature = "std")] pub mod io; mod interleave_pending; +mod track_closed; diff --git a/futures-test/src/sink/mod.rs b/futures-test/src/sink/mod.rs new file mode 100644 index 0000000000..c3ebcfd65e --- /dev/null +++ b/futures-test/src/sink/mod.rs @@ -0,0 +1,56 @@ +//! Additional combinators for testing sinks. + +use futures_sink::Sink; + +pub use crate::track_closed::TrackClosed; + +/// Additional combinators for testing sinks. +pub trait SinkTestExt: Sink { + /// Track whether this sink has been closed and panics if it is used after closing. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::sink::{SinkExt, drain}; + /// use futures_test::sink::SinkTestExt; + /// + /// let mut sink = drain::().track_closed(); + /// + /// sink.send(1).await?; + /// assert!(!sink.is_closed()); + /// sink.close().await?; + /// assert!(sink.is_closed()); + /// + /// # Ok::<(), std::convert::Infallible>(()) })?; + /// # Ok::<(), std::convert::Infallible>(()) + /// ``` + /// + /// Note: Unlike [`AsyncWriteTestExt::track_closed`] when + /// used as a sink the adaptor will panic if closed too early as there's no easy way to + /// integrate as an error. + /// + /// [`AsyncWriteTestExt::track_closed`]: crate::io::AsyncWriteTestExt::track_closed + /// + /// ``` + /// # futures::executor::block_on(async { + /// use std::panic::AssertUnwindSafe; + /// use futures::{sink::{SinkExt, drain}, future::FutureExt}; + /// use futures_test::sink::SinkTestExt; + /// + /// let mut sink = drain::().track_closed(); + /// + /// sink.close().await?; + /// assert!(AssertUnwindSafe(sink.send(1)).catch_unwind().await.is_err()); + /// # Ok::<(), std::convert::Infallible>(()) })?; + /// # Ok::<(), std::convert::Infallible>(()) + /// ``` + fn track_closed(self) -> TrackClosed + where + Self: Sized, + { + TrackClosed::new(self) + } +} + +impl SinkTestExt for W where W: Sink {} diff --git a/futures-test/src/track_closed.rs b/futures-test/src/track_closed.rs new file mode 100644 index 0000000000..28614df85f --- /dev/null +++ b/futures-test/src/track_closed.rs @@ -0,0 +1,146 @@ +use futures_io::AsyncWrite; +use futures_sink::Sink; +use std::{ + io::{self, IoSlice}, + pin::Pin, + task::{Context, Poll}, +}; + +/// Async wrapper that tracks whether it has been closed. +/// +/// See the `track_closed` methods on: +/// * [`SinkTestExt`](crate::sink::SinkTestExt::track_closed) +/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::track_closed) +#[pin_project::pin_project] +#[derive(Debug)] +pub struct TrackClosed { + #[pin] + inner: T, + closed: bool, +} + +impl TrackClosed { + pub(crate) fn new(inner: T) -> TrackClosed { + TrackClosed { + inner, + closed: false, + } + } + + /// Check whether this object has been closed. + pub fn is_closed(&self) -> bool { + self.closed + } + + /// Acquires a reference to the underlying object that this adaptor is + /// wrapping. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Acquires a mutable reference to the underlying object that this + /// adaptor is wrapping. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Acquires a pinned mutable reference to the underlying object that + /// this adaptor is wrapping. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner + } + + /// Consumes this adaptor returning the underlying object. + pub fn into_inner(self) -> T { + self.inner + } +} + +impl AsyncWrite for TrackClosed { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.is_closed() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Attempted to write after stream was closed", + ))); + } + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_closed() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Attempted to flush after stream was closed", + ))); + } + assert!(!self.is_closed()); + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_closed() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Attempted to close after stream was closed", + ))); + } + let this = self.project(); + match this.inner.poll_close(cx) { + Poll::Ready(Ok(())) => { + *this.closed = true; + Poll::Ready(Ok(())) + } + other => other, + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + if self.is_closed() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Attempted to write after stream was closed", + ))); + } + self.project().inner.poll_write_vectored(cx, bufs) + } +} + +impl> Sink for TrackClosed { + type Error = T::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + assert!(!self.is_closed()); + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + assert!(!self.is_closed()); + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + assert!(!self.is_closed()); + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + assert!(!self.is_closed()); + let this = self.project(); + match this.inner.poll_close(cx) { + Poll::Ready(Ok(())) => { + *this.closed = true; + Poll::Ready(Ok(())) + } + other => other, + } + } +}