diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 6d8377c2830..94e8e1cce8a 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -30,6 +30,7 @@ default = [] # enable everything full = [ + "async-fd", "blocking", "dns", "fs", @@ -52,7 +53,7 @@ dns = ["rt-core"] fs = ["rt-core", "io-util"] io-util = ["memchr"] # stdin, stdout, stderr -io-std = ["rt-core"] +io-std = ["rt-core", "mio/os-util"] macros = ["tokio-macros"] net = ["dns", "tcp", "udp", "uds"] process = [ @@ -76,6 +77,7 @@ signal = [ "libc", "mio/os-poll", "mio/uds", + "mio/os-util", "signal-hook-registry", "winapi/consoleapi", ] @@ -86,6 +88,7 @@ tcp = ["lazy_static", "mio/tcp", "mio/os-poll"] time = [] udp = ["lazy_static", "mio/udp", "mio/os-poll"] uds = ["lazy_static", "libc", "mio/uds", "mio/os-poll"] +async-fd = ["lazy_static", "mio/udp", "mio/os-poll", "mio/os-util"] [dependencies] tokio-macros = { version = "0.3.0", path = "../tokio-macros", optional = true } @@ -108,6 +111,10 @@ tracing = { version = "0.1.16", default-features = false, features = ["std"], op libc = { version = "0.2.42", optional = true } signal-hook-registry = { version = "1.1.1", optional = true } +[target.'cfg(unix)'.dev-dependencies] +libc = { version = "0.2.42" } +nix = { version = "0.18.0" } + [target.'cfg(windows)'.dependencies.winapi] version = "0.3.8" default-features = false diff --git a/tokio/src/io/async_fd.rs b/tokio/src/io/async_fd.rs new file mode 100644 index 00000000000..b890b5a5bdc --- /dev/null +++ b/tokio/src/io/async_fd.rs @@ -0,0 +1,269 @@ +use std::os::unix::io::RawFd; + +use std::io; + +use mio::unix::SourceFd; + +use crate::io::driver::{Handle, ReadyEvent, ScheduledIo}; +use crate::util::slab; + +/// Associates a Unix file descriptor with the tokio reactor, allowing for +/// readiness to be polled. +/// +/// Creating an AsyncFd registers the file descriptor with the current tokio +/// Reactor, allowing you to directly await the file descriptor being readable +/// or writable. Once registered, the file descriptor remains registered until +/// the AsyncFd is dropped. +/// +/// It is the responsibility of the caller to ensure that the AsyncFd is dropped +/// before the associated file descriptor is closed. Failing to do so may result +/// in spurious events or mysterious errors from other tokio IO calls. +/// +/// Polling for readiness is done by calling the async functions [`readable`] +/// and [`writable`]. These functions complete when the associated readiness +/// condition is observed. Any number of tasks can query the same `AsyncFd` +/// in parallel, on the same or different conditions. +/// +/// On some platforms, the readiness detecting mechanism relies on +/// edge-triggered notifications. This means that the OS will only notify Tokio +/// when the file descriptor transitions from not-ready to ready. Tokio +/// internally tracks when it has received a ready notification, and when +/// readiness checking functions like [`readable`] and [`writable`] are called, +/// if the readiness flag is set, these async functions will complete +/// immediately. +/// +/// This however does mean that it is critical to ensure that this ready flag is +/// cleared when (and only when) the file descriptor ceases to be ready. The +/// [`ReadyGuard`] returned from readiness checking functions serves this +/// function; after calling a readiness-checking async function, you must use +/// this [`ReadyGuard`] to signal to tokio whether the file descriptor is no +/// longer in a ready state. +/// +/// ## Converting to a poll-based API +/// +/// In some cases it may be desirable to use `AsyncFd` from APIs similar to +/// [`TcpStream::poll_read_ready`]. One can do so by allocating a pinned future +/// to perform the poll: +/// +/// ``` +/// use tokio::io::{ReadyGuard, AsyncFd}; +/// +/// use std::future::Future; +/// use std::sync::Arc; +/// use std::pin::Pin; +/// use std::task::{Context, Poll}; +/// +/// use futures::ready; +/// +/// struct MyIoStruct { +/// async_fd: Arc, +/// poller: Pin>> +/// } +/// +/// impl MyIoStruct { +/// fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll>> { +/// let mut result = Poll::Pending; +/// while result.is_pending() { +/// // Poll the saved future; if it's not ready, our context waker will be saved in the +/// // future and we can return. +/// ready!(self.poller.as_mut().poll(cx)); +/// +/// // Reset the poller future, since we consumed it. +/// let arc = self.async_fd.clone(); +/// self.poller = Box::pin(async move { +/// let _ = arc.readable().await.map(|mut guard| guard.retain_ready()); +/// }); +/// +/// // Because we need to bind the ReadyGuard to the lifetime of self, we have to re-poll here. +/// // It's possible that we might race with another thread clearing the ready state, so deal +/// // with that as well. +/// let fut = self.async_fd.readable(); +/// tokio::pin!(fut); +/// result = fut.as_mut().poll(cx); +/// } +/// +/// result +/// } +/// } +/// ``` +/// +/// [`readable`]: method@Self::readable +/// [`writable`]: method@Self::writable +/// [`ReadyGuard`]: struct@self::ReadyGuard +/// [`TcpStream::poll_read_ready`]: struct@crate::net::TcpStream +pub struct AsyncFd { + handle: Handle, + fd: RawFd, + shared: slab::Ref, +} + +impl std::fmt::Debug for AsyncFd { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncFd").field("fd", &self.fd).finish() + } +} + +unsafe impl Send for AsyncFd {} +unsafe impl Sync for AsyncFd {} + +const fn all_interest() -> mio::Interest { + mio::Interest::READABLE.add(mio::Interest::WRITABLE) +} + +/// Represents an IO-ready event detected on a particular file descriptor, which +/// has not yet been acknowledged. This is a `must_use` structure to help ensure +/// that you do not forget to explicitly clear (or not clear) the event. +#[must_use = "You must explicitly choose whether to clear the readiness state by calling a method on ReadyGuard"] +pub struct ReadyGuard<'a> { + async_fd: &'a AsyncFd, + event: Option, +} + +impl<'a> std::fmt::Debug for ReadyGuard<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClearReady") + .field("async_fd", self.async_fd) + .finish() + } +} + +impl<'a> ReadyGuard<'a> { + /// Indicates to tokio that the file descriptor is no longer ready. The + /// internal readiness flag will be cleared, and tokio will wait for the + /// next edge-triggered readiness notification from the OS. + /// + /// It is critical that this function not be called unless your code + /// _actually observes_ that the file descriptor is _not_ ready. Do not call + /// it simply because, for example, a read succeeded; it should be called + /// when a read is observed to block. + /// + /// [`drop`]: method@std::mem::drop + pub fn clear_ready(&mut self) { + if let Some(event) = self.event.take() { + self.async_fd.shared.clear_readiness(event); + } + } + + /// This function should be invoked when you intentionally want to keep the + /// ready flag asserted. + /// + /// While this function is itself a no-op, it satisfies the `#[must_use]` + /// constraint on the [`ReadyGuard`] type. + pub fn retain_ready(&mut self) { + // no-op + } + + /// Performs the IO operation `f`; if `f` returns a [`WouldBlock`] error, + /// the readiness state associated with this file descriptor is cleared. + /// + /// This method helps ensure that the readiness state of the underlying file + /// descriptor remains in sync with the tokio-side readiness state, by + /// clearing the tokio-side state only when a [`WouldBlock`] condition + /// occurs. It is the responsibility of the caller to ensure that `f` + /// returns [`WouldBlock`] only if the file descriptor that originated this + /// `ReadyGuard` no longer expresses the readiness state that was queried to + /// create this `ReadyGuard`. + /// + /// [`WouldBlock`]: std::io::ErrorKind::WouldBlock + pub fn with_io(&mut self, f: impl FnOnce() -> Result) -> Result + where + E: std::error::Error + 'static, + { + use std::error::Error; + + let result = f(); + + if let Err(e) = result.as_ref() { + // Is this a WouldBlock error? + let mut error_ref: Option<&(dyn Error + 'static)> = Some(e); + + while let Some(current) = error_ref { + if let Some(e) = Error::downcast_ref::(current) { + if e.kind() == std::io::ErrorKind::WouldBlock { + self.clear_ready(); + break; + } + } + error_ref = current.source(); + } + } + + result + } + + /// Performs the IO operation `f`; if `f` returns [`Pending`], the readiness + /// state associated with this file descriptor is cleared. + /// + /// This method helps ensure that the readiness state of the underlying file + /// descriptor remains in sync with the tokio-side readiness state, by + /// clearing the tokio-side state only when a [`Pending`] condition occurs. + /// It is the responsibility of the caller to ensure that `f` returns + /// [`Pending`] only if the file descriptor that originated this + /// `ReadyGuard` no longer expresses the readiness state that was queried to + /// create this `ReadyGuard`. + /// + /// [`Pending`]: std::task::Poll::Pending + pub fn with_poll(&mut self, f: impl FnOnce() -> std::task::Poll) -> std::task::Poll { + let result = f(); + + if result.is_pending() { + self.clear_ready(); + } + + result + } +} + +impl Drop for AsyncFd { + fn drop(&mut self) { + if let Some(inner) = self.handle.inner() { + let _ = inner.deregister_source(&mut SourceFd(&self.fd)); + } + } +} + +impl AsyncFd { + /// Constructs a new AsyncFd, binding this file descriptor to the current tokio Reactor. + /// + /// This function must be called in the context of a tokio runtime. + pub fn new(fd: RawFd) -> io::Result { + Self::new_with_handle(fd, Handle::current()) + } + + pub(crate) fn new_with_handle(fd: RawFd, handle: Handle) -> io::Result { + let shared = if let Some(inner) = handle.inner() { + inner.add_source(&mut SourceFd(&fd), all_interest())? + } else { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to find event loop", + )); + }; + + Ok(AsyncFd { handle, fd, shared }) + } + + async fn readiness(&self, interest: mio::Interest) -> io::Result> { + let event = self.shared.readiness(interest).await; + Ok(ReadyGuard { + async_fd: self, + event: Some(event), + }) + } + + /// Waits for the file descriptor to become readable, returning a + /// [`ReadyGuard`] that must be dropped to resume read-readiness polling. + /// + /// [`ReadyGuard`]: struct@self::ReadyGuard + pub async fn readable(&self) -> io::Result> { + self.readiness(mio::Interest::READABLE).await + } + + /// Waits for the file descriptor to become writable, returning a + /// [`ReadyGuard`] that must be dropped to resume write-readiness polling. + /// + /// [`ReadyGuard`]: struct@self::ReadyGuard + pub async fn writable(&self) -> io::Result> { + self.readiness(mio::Interest::WRITABLE).await + } +} diff --git a/tokio/src/io/driver/mod.rs b/tokio/src/io/driver/mod.rs index c4f5887a930..41469c1964d 100644 --- a/tokio/src/io/driver/mod.rs +++ b/tokio/src/io/driver/mod.rs @@ -56,10 +56,12 @@ pub(super) struct Inner { waker: mio::Waker, } -#[derive(Debug, Eq, PartialEq, Clone, Copy)] -pub(super) enum Direction { - Read, - Write, +cfg_io_poll_evented! { + #[derive(Debug, Eq, PartialEq, Clone, Copy)] + pub(super) enum Direction { + Read, + Write, + } } enum Tick { @@ -292,12 +294,13 @@ impl Inner { self.registry.deregister(source) } } - -impl Direction { - pub(super) fn mask(self) -> Ready { - match self { - Direction::Read => Ready::READABLE | Ready::READ_CLOSED, - Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED, +cfg_io_poll_evented! { + impl Direction { + pub(super) fn mask(self) -> Ready { + match self { + Direction::Read => Ready::READABLE | Ready::READ_CLOSED, + Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED, + } } } } diff --git a/tokio/src/io/driver/scheduled_io.rs b/tokio/src/io/driver/scheduled_io.rs index bdf217987d2..73d729968bc 100644 --- a/tokio/src/io/driver/scheduled_io.rs +++ b/tokio/src/io/driver/scheduled_io.rs @@ -1,4 +1,4 @@ -use super::{Direction, Ready, ReadyEvent, Tick}; +use super::{Ready, ReadyEvent, Tick}; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Mutex; use crate::util::bit; @@ -7,6 +7,10 @@ use crate::util::slab::Entry; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; use std::task::{Context, Poll, Waker}; +cfg_io_poll_evented! { + use super::Direction; +} + cfg_io_readiness! { use crate::util::linked_list::{self, LinkedList}; @@ -32,7 +36,7 @@ cfg_io_readiness! { #[derive(Debug, Default)] struct Waiters { - #[cfg(any(feature = "udp", feature = "uds"))] + #[cfg(any(feature = "udp", feature = "uds", feature = "async-fd"))] /// List of all current waiters list: WaitList, @@ -203,7 +207,7 @@ impl ScheduledIo { } } - #[cfg(any(feature = "udp", feature = "uds"))] + #[cfg(any(feature = "udp", feature = "uds", feature = "async-fd"))] { // check list of waiters for waiter in waiters.list.drain_filter(|w| ready.satisfies(w.interest)) { @@ -216,46 +220,48 @@ impl ScheduledIo { } } - /// Poll version of checking readiness for a certain direction. - /// - /// These are to support `AsyncRead` and `AsyncWrite` polling methods, - /// which cannot use the `async fn` version. This uses reserved reader - /// and writer slots. - pub(in crate::io) fn poll_readiness( - &self, - cx: &mut Context<'_>, - direction: Direction, - ) -> Poll { - let curr = self.readiness.load(Acquire); - - let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); - - if ready.is_empty() { - // Update the task info - let mut waiters = self.waiters.lock(); - let slot = match direction { - Direction::Read => &mut waiters.reader, - Direction::Write => &mut waiters.writer, - }; - *slot = Some(cx.waker().clone()); - - // Try again, in case the readiness was changed while we were - // taking the waiters lock + cfg_io_poll_evented! { + /// Poll version of checking readiness for a certain direction. + /// + /// These are to support `AsyncRead` and `AsyncWrite` polling methods, + /// which cannot use the `async fn` version. This uses reserved reader + /// and writer slots. + pub(in crate::io) fn poll_readiness( + &self, + cx: &mut Context<'_>, + direction: Direction, + ) -> Poll { let curr = self.readiness.load(Acquire); + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + if ready.is_empty() { - Poll::Pending + // Update the task info + let mut waiters = self.waiters.lock(); + let slot = match direction { + Direction::Read => &mut waiters.reader, + Direction::Write => &mut waiters.writer, + }; + *slot = Some(cx.waker().clone()); + + // Try again, in case the readiness was changed while we were + // taking the waiters lock + let curr = self.readiness.load(Acquire); + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + if ready.is_empty() { + Poll::Pending + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } } else { Poll::Ready(ReadyEvent { tick: TICK.unpack(curr) as u8, ready, }) } - } else { - Poll::Ready(ReadyEvent { - tick: TICK.unpack(curr) as u8, - ready, - }) } } diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 7eba6d14972..b1c403bd70c 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -206,12 +206,21 @@ pub use std::io::{Error, ErrorKind, Result, SeekFrom}; cfg_io_driver! { pub(crate) mod driver; +} + +cfg_io_poll_evented! { + mod registration; mod poll_evented; + #[cfg(not(loom))] pub(crate) use poll_evented::PollEvented; +} - mod registration; +cfg_async_fd_unix! { + mod async_fd; + + pub use self::async_fd::{AsyncFd, ReadyGuard}; } cfg_io_std! { diff --git a/tokio/src/io/poll_evented.rs b/tokio/src/io/poll_evented.rs index 4457195fd76..c7df8bf75ab 100644 --- a/tokio/src/io/poll_evented.rs +++ b/tokio/src/io/poll_evented.rs @@ -125,6 +125,7 @@ impl PollEvented { /// Returns a shared reference to the underlying I/O object this readiness /// stream is wrapping. #[cfg(any( + feature = "async-fd", feature = "process", feature = "tcp", feature = "udp", diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 1b0dad5d667..620fa40b5a9 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -83,6 +83,8 @@ //! - `tcp`: Enables all `tokio::net::tcp` types. //! - `udp`: Enables all `tokio::net::udp` types. //! - `uds`: Enables all `tokio::net::unix` types. +//! - `async-fd`: Enables the `tokio::io::AsyncFd` and associated types (available on +//! unix-like systems only). //! - `time`: Enables `tokio::time` types and allows the schedulers to enable //! the built in timer. //! - `process`: Enables `tokio::process` types. diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index f245b09e921..67da12adefc 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -4,6 +4,7 @@ macro_rules! cfg_resource_drivers { ($($item:item)*) => { $( #[cfg(any( + all(unix, feature = "async-fd"), feature = "process", all(unix, feature = "signal"), all(not(loom), feature = "tcp"), @@ -140,6 +141,7 @@ macro_rules! cfg_io_driver { ($($item:item)*) => { $( #[cfg(any( + all(unix, feature = "async-fd"), feature = "process", all(unix, feature = "signal"), feature = "tcp", @@ -147,6 +149,7 @@ macro_rules! cfg_io_driver { feature = "uds", ))] #[cfg_attr(docsrs, doc(cfg(any( + all(unix, feature = "async-fd"), feature = "process", all(unix, feature = "signal"), feature = "tcp", @@ -158,10 +161,47 @@ macro_rules! cfg_io_driver { } } +macro_rules! cfg_io_poll_evented { + ($($item:item)*) => { + $( + #[cfg(any( + feature = "process", + all(unix, feature = "signal"), + feature = "tcp", + feature = "udp", + feature = "uds", + ))] + #[cfg_attr(docsrs, doc(cfg(any( + feature = "process", + all(unix, feature = "signal"), + feature = "tcp", + feature = "udp", + feature = "uds", + ))))] + $item + )* + } +} + +macro_rules! cfg_async_fd_unix { + ($($item:item)*) => { + $( + #[cfg( + all(unix, feature = "async-fd"), + )] + #[cfg_attr(docsrs, doc(cfg( + all(unix, feature = "async-fd"), + )))] + $item + )* + } +} + macro_rules! cfg_not_io_driver { ($($item:item)*) => { $( #[cfg(not(any( + all(unix, feature = "async-fd"), feature = "process", all(unix, feature = "signal"), feature = "tcp", @@ -176,7 +216,7 @@ macro_rules! cfg_not_io_driver { macro_rules! cfg_io_readiness { ($($item:item)*) => { $( - #[cfg(any(feature = "udp", feature = "uds"))] + #[cfg(any(feature = "udp", feature = "uds", feature="async-fd"))] $item )* } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index d43666d3c0f..1cd20a93716 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -139,6 +139,7 @@ impl Builder { /// ``` pub fn enable_all(&mut self) -> &mut Self { #[cfg(any( + feature = "async-fd", feature = "process", all(unix, feature = "signal"), feature = "tcp", diff --git a/tokio/tests/io_async_fd.rs b/tokio/tests/io_async_fd.rs new file mode 100644 index 00000000000..79caf50e223 --- /dev/null +++ b/tokio/tests/io_async_fd.rs @@ -0,0 +1,279 @@ +#![warn(rust_2018_idioms)] +#![cfg(all( + unix, + feature = "async-fd", + feature = "test-util", + feature = "macros", + feature = "rt-threaded" +))] + +use futures::FutureExt; +use std::io::{self, ErrorKind, Read, Write}; +use std::{os::unix::io::RawFd, sync::Arc, time::Duration}; +use tokio::io::AsyncFd; + +use nix::errno::Errno; +use nix::unistd::{close, read, write}; + +fn is_blocking(e: &nix::Error) -> bool { + Some(Errno::EAGAIN) == e.as_errno() +} + +struct FileDescriptor { + fd: RawFd, +} + +impl Read for FileDescriptor { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match read(self.fd, buf) { + Ok(n) => Ok(n), + Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), + Err(e) => Err(io::Error::new(ErrorKind::Other, e)), + } + } +} + +impl Write for FileDescriptor { + fn write(&mut self, buf: &[u8]) -> io::Result { + match write(self.fd, buf) { + Ok(n) => Ok(n), + Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), + Err(e) => Err(io::Error::new(ErrorKind::Other, e)), + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl Drop for FileDescriptor { + fn drop(&mut self) { + let _ = close(self.fd); + } +} + +fn set_nonblocking(fd: RawFd) { + use nix::fcntl::{OFlag, F_GETFL, F_SETFL}; + + let flags = nix::fcntl::fcntl(fd, F_GETFL).expect("fcntl(F_GETFD)"); + + if flags < 0 { + panic!( + "bad return value from fcntl(F_GETFL): {} ({:?})", + flags, + nix::Error::last() + ); + } + + let flags = OFlag::from_bits_truncate(flags) | OFlag::O_NONBLOCK; + + nix::fcntl::fcntl(fd, F_SETFL(flags)).expect("fcntl(F_SETFD)"); +} + +fn socketpair() -> (FileDescriptor, FileDescriptor) { + use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType}; + + let (fd_a, fd_b) = socketpair( + AddressFamily::Unix, + SockType::Stream, + None, + SockFlag::empty(), + ) + .expect("socketpair"); + let fds = (FileDescriptor { fd: fd_a }, FileDescriptor { fd: fd_b }); + + set_nonblocking(fds.0.fd); + set_nonblocking(fds.1.fd); + + fds +} + +fn drain(fd: &mut FileDescriptor) { + let mut buf = [0u8; 512]; + + loop { + match fd.read(&mut buf[..]) { + Err(e) if e.kind() == ErrorKind::WouldBlock => break, + Ok(0) => panic!("unexpected EOF"), + Err(e) => panic!("unexpected error: {:?}", e), + Ok(_) => continue, + } + } +} + +#[tokio::test] +async fn initially_writable() { + let (a, b) = socketpair(); + + let afd_a = AsyncFd::new(a.fd).unwrap(); + let afd_b = AsyncFd::new(b.fd).unwrap(); + + afd_a.writable().await.unwrap().clear_ready(); + afd_b.writable().await.unwrap().clear_ready(); + + futures::select_biased! { + _ = tokio::time::sleep(Duration::from_millis(10)).fuse() => {}, + _ = afd_a.readable().fuse() => panic!("Unexpected readable state"), + _ = afd_b.readable().fuse() => panic!("Unexpected readable state"), + } +} + +#[tokio::test] +async fn reset_readable() { + let (mut a, mut b) = socketpair(); + + let afd_a = AsyncFd::new(a.fd).unwrap(); + + let readable = afd_a.readable(); + tokio::pin!(readable); + + tokio::select! { + _ = readable.as_mut() => panic!(), + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + } + + b.write_all(b"0").unwrap(); + + let mut guard = readable.await.unwrap(); + + guard.with_io(|| a.read(&mut [0])).unwrap(); + + // `a` is not readable, but the reactor still thinks it is + // (because we have not observed a not-ready error yet) + afd_a.readable().await.unwrap().retain_ready(); + + // Explicitly clear the ready state + guard.clear_ready(); + + let readable = afd_a.readable(); + tokio::pin!(readable); + + tokio::select! { + _ = readable.as_mut() => panic!(), + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + } + + b.write_all(b"0").unwrap(); + + // We can observe the new readable event + afd_a.readable().await.unwrap().clear_ready(); +} + +#[tokio::test] +async fn reset_writable() { + let (mut a, mut b) = socketpair(); + + let afd_a = AsyncFd::new(a.fd).unwrap(); + + let mut guard = afd_a.writable().await.unwrap(); + + // Write until we get a WouldBlock. This also clears the ready state. + loop { + if let Err(e) = guard.with_io(|| a.write(&[0; 512][..])) { + assert_eq!(ErrorKind::WouldBlock, e.kind()); + break; + } + } + + // Writable state should be cleared now. + let writable = afd_a.writable(); + tokio::pin!(writable); + + tokio::select! { + _ = writable.as_mut() => panic!(), + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + } + + // Read from the other side; we should become writable now. + drain(&mut b); + + let _ = writable.await.unwrap(); +} + +#[tokio::test] +async fn with_poll() { + use std::task::Poll; + + let (mut a, mut b) = socketpair(); + + b.write_all(b"0").unwrap(); + + let afd_a = AsyncFd::new(a.fd).unwrap(); + + let mut guard = afd_a.readable().await.unwrap(); + + a.read_exact(&mut [0]).unwrap(); + + // Should not clear the readable state + let _ = guard.with_poll(|| Poll::Ready(())); + + // Still readable... + let _ = afd_a.readable().await.unwrap(); + + // Should clear the readable state + let _ = guard.with_poll(|| Poll::Pending::<()>); + + // Assert not readable + let readable = afd_a.readable(); + tokio::pin!(readable); + + tokio::select! { + _ = readable.as_mut() => panic!(), + _ = tokio::time::sleep(Duration::from_millis(10)) => {} + } + + // Write something down b again and make sure we're reawoken + b.write_all(b"0").unwrap(); + let _ = readable.await.unwrap(); +} + +#[tokio::test] +async fn multiple_waiters() { + let (a, mut b) = socketpair(); + let a = Arc::new(a); + let afd_a = Arc::new(AsyncFd::new(a.fd).unwrap()); + + let barrier = Arc::new(tokio::sync::Barrier::new(11)); + + let mut tasks = Vec::new(); + for _ in 0..10 { + let a = a.clone(); + let afd_a = afd_a.clone(); + let barrier = barrier.clone(); + + let f = async move { + let notify_barrier = async { + barrier.wait().await; + futures::future::pending::<()>().await; + }; + + futures::select_biased! { + guard = afd_a.readable().fuse() => { + tokio::task::yield_now().await; + guard.unwrap().clear_ready() + }, + _ = notify_barrier.fuse() => unreachable!(), + } + + std::mem::drop(afd_a); + std::mem::drop(a); // drop the fd only once we deregister + }; + + tasks.push(tokio::spawn(f)); + } + + let mut all_tasks = futures::future::try_join_all(tasks); + + tokio::select! { + r = std::pin::Pin::new(&mut all_tasks) => { + r.unwrap(); // propagate panic + panic!("Tasks exited unexpectedly") + }, + _ = barrier.wait() => {} + }; + + b.write_all(b"0").unwrap(); + + all_tasks.await.unwrap(); +}