diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml
index 57e49d68a81..6589f485299 100644
--- a/tokio/Cargo.toml
+++ b/tokio/Cargo.toml
@@ -53,6 +53,7 @@ net = [
     "lazy_static",
     "libc",
     "mio/os-poll",
+    "mio/os-util",
     "mio/tcp",
     "mio/udp",
     "mio/uds",
@@ -78,6 +79,7 @@ signal = [
     "libc",
     "mio/os-poll",
     "mio/uds",
+    "mio/os-util",
     "signal-hook-registry",
     "winapi/consoleapi",
 ]
@@ -107,6 +109,10 @@ tracing = { version = "0.1.16", default-features = false, features = ["std"], optional = true }
 libc = { version = "0.2.42", optional = true }
 signal-hook-registry = { version = "1.1.1", optional = true }
 
+[target.'cfg(unix)'.dev-dependencies]
+libc = { version = "0.2.42" }
+nix = { version = "0.18.0" }
+
 [target.'cfg(windows)'.dependencies.winapi]
 version = "0.3.8"
 default-features = false
diff --git a/tokio/src/io/async_fd.rs b/tokio/src/io/async_fd.rs
new file mode 100644
index 00000000000..e5ad2ab4314
--- /dev/null
+++ b/tokio/src/io/async_fd.rs
@@ -0,0 +1,337 @@
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::{task::Context, task::Poll};
+
+use std::io;
+
+use mio::unix::SourceFd;
+
+use crate::io::driver::{Direction, Handle, ReadyEvent, ScheduledIo};
+use crate::util::slab;
+
+/// Associates an IO object backed by a Unix file descriptor with the tokio
+/// reactor, allowing for readiness to be polled. The file descriptor must be
+/// of a type that can be used with the OS polling facilities (i.e., `poll`,
+/// `epoll`, `kqueue`, etc.), such as a network socket or pipe.
+///
+/// Creating an AsyncFd registers the file descriptor with the current tokio
+/// reactor, allowing you to directly await the file descriptor being readable
+/// or writable. Once registered, the file descriptor remains registered until
+/// the AsyncFd is dropped.
+///
+/// The AsyncFd takes ownership of an arbitrary object to represent the IO
+/// object. It is intended that this object will handle closing the file
+/// descriptor when it is dropped, avoiding resource leaks and ensuring that
+/// the AsyncFd can clean up the registration before closing the file
+/// descriptor. The [`AsyncFd::into_inner`] function can be used to extract
+/// the inner object to retake control from the tokio IO reactor.
+///
+/// The inner object is required to implement [`AsRawFd`]. This file
+/// descriptor must not change while [`AsyncFd`] owns the inner object.
+/// Changing the file descriptor results in unspecified behavior in the IO
+/// driver, which may include breaking notifications for other file
+/// descriptors.
+///
+/// Polling for readiness is done by calling the async functions [`readable`]
+/// and [`writable`]. These functions complete when the associated readiness
+/// condition is observed. Any number of tasks can query the same `AsyncFd` in
+/// parallel, on the same or different conditions.
+///
+/// On some platforms, the readiness-detecting mechanism relies on
+/// edge-triggered notifications. This means that the OS will only notify
+/// Tokio when the file descriptor transitions from not-ready to ready. Tokio
+/// internally tracks when it has received a ready notification, and when
+/// readiness-checking functions like [`readable`] and [`writable`] are
+/// called, if the readiness flag is set, these async functions will complete
+/// immediately.
+///
+/// This does, however, mean that it is critical to ensure that this ready
+/// flag is cleared when (and only when) the file descriptor ceases to be
+/// ready.
+/// The [`AsyncFdReadyGuard`] returned from readiness-checking functions
+/// serves this purpose; after calling a readiness-checking async function,
+/// you must use this [`AsyncFdReadyGuard`] to signal to tokio whether the
+/// file descriptor is no longer in a ready state.
+///
+/// ## Use with a poll-based API
+///
+/// In some cases it may be desirable to use `AsyncFd` from APIs similar to
+/// [`TcpStream::poll_read_ready`]. The [`AsyncFd::poll_read_ready`] and
+/// [`AsyncFd::poll_write_ready`] functions are provided for this purpose.
+/// Because these functions don't create a future to hold their state, they
+/// have the limitation that only one task can wait on each direction (read or
+/// write) at a time.
+///
+/// [`readable`]: method@Self::readable
+/// [`writable`]: method@Self::writable
+/// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard
+/// [`TcpStream::poll_read_ready`]: struct@crate::net::TcpStream
+pub struct AsyncFd<T: AsRawFd> {
+    handle: Handle,
+    shared: slab::Ref<ScheduledIo>,
+    inner: Option<T>,
+}
+
+impl<T: AsRawFd> AsRawFd for AsyncFd<T> {
+    fn as_raw_fd(&self) -> RawFd {
+        self.inner.as_ref().unwrap().as_raw_fd()
+    }
+}
+
+impl<T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFd<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AsyncFd")
+            .field("inner", &self.inner)
+            .finish()
+    }
+}
+
+const ALL_INTEREST: mio::Interest = mio::Interest::READABLE.add(mio::Interest::WRITABLE);
+
+/// Represents an IO-ready event detected on a particular file descriptor,
+/// which has not yet been acknowledged. This is a `must_use` structure to
+/// help ensure that you do not forget to explicitly clear (or not clear) the
+/// event.
+#[must_use = "You must explicitly choose whether to clear the readiness state by calling a method on AsyncFdReadyGuard"]
+pub struct AsyncFdReadyGuard<'a, T: AsRawFd> {
+    async_fd: &'a AsyncFd<T>,
+    event: Option<ReadyEvent>,
+}
+
+impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyGuard<'a, T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AsyncFdReadyGuard")
+            .field("async_fd", &self.async_fd)
+            .finish()
+    }
+}
+
+impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> {
+    /// Indicates to tokio that the file descriptor is no longer ready. The
+    /// internal readiness flag will be cleared, and tokio will wait for the
+    /// next edge-triggered readiness notification from the OS.
+    ///
+    /// It is critical that this function not be called unless your code
+    /// _actually observes_ that the file descriptor is _not_ ready. Do not
+    /// call it simply because, for example, a read succeeded; it should be
+    /// called when a read is observed to block.
+    pub fn clear_ready(&mut self) {
+        if let Some(event) = self.event.take() {
+            self.async_fd.shared.clear_readiness(event);
+        }
+    }
+
+    /// This function should be invoked when you intentionally want to keep
+    /// the ready flag asserted.
+    ///
+    /// While this function is itself a no-op, it satisfies the `#[must_use]`
+    /// constraint on the [`AsyncFdReadyGuard`] type.
+    pub fn retain_ready(&mut self) {
+        // no-op
+    }
+
+    /// Performs the IO operation `f`; if `f` returns a [`WouldBlock`] error,
+    /// the readiness state associated with this file descriptor is cleared.
+    ///
+    /// This method helps ensure that the readiness state of the underlying
+    /// file descriptor remains in sync with the tokio-side readiness state,
+    /// by clearing the tokio-side state only when a [`WouldBlock`] condition
+    /// occurs.
+    /// It is the responsibility of the caller to ensure that `f` returns
+    /// [`WouldBlock`] only if the file descriptor that originated this
+    /// `AsyncFdReadyGuard` no longer expresses the readiness state that was
+    /// queried to create this `AsyncFdReadyGuard`.
+    ///
+    /// [`WouldBlock`]: std::io::ErrorKind::WouldBlock
+    pub fn with_io<R>(&mut self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
+        let result = f();
+
+        if let Err(e) = result.as_ref() {
+            if e.kind() == io::ErrorKind::WouldBlock {
+                self.clear_ready();
+            }
+        }
+
+        result
+    }
+
+    /// Performs the IO operation `f`; if `f` returns [`Pending`], the
+    /// readiness state associated with this file descriptor is cleared.
+    ///
+    /// This method helps ensure that the readiness state of the underlying
+    /// file descriptor remains in sync with the tokio-side readiness state,
+    /// by clearing the tokio-side state only when a [`Pending`] condition
+    /// occurs. It is the responsibility of the caller to ensure that `f`
+    /// returns [`Pending`] only if the file descriptor that originated this
+    /// `AsyncFdReadyGuard` no longer expresses the readiness state that was
+    /// queried to create this `AsyncFdReadyGuard`.
+    ///
+    /// [`Pending`]: std::task::Poll::Pending
+    pub fn with_poll<R>(&mut self, f: impl FnOnce() -> std::task::Poll<R>) -> std::task::Poll<R> {
+        let result = f();
+
+        if result.is_pending() {
+            self.clear_ready();
+        }
+
+        result
+    }
+}
+
+impl<T: AsRawFd> Drop for AsyncFd<T> {
+    fn drop(&mut self) {
+        if let Some(driver) = self.handle.inner() {
+            if let Some(inner) = self.inner.as_ref() {
+                let fd = inner.as_raw_fd();
+                let _ = driver.deregister_source(&mut SourceFd(&fd));
+            }
+        }
+    }
+}
+
+impl<T: AsRawFd> AsyncFd<T> {
+    /// Creates an AsyncFd backed by (and taking ownership of) an object
+    /// implementing [`AsRawFd`]. The backing file descriptor is cached at the
+    /// time of creation.
+    ///
+    /// This function must be called in the context of a tokio runtime.
+    pub fn new(inner: T) -> io::Result<Self>
+    where
+        T: AsRawFd,
+    {
+        Self::new_with_handle(inner, Handle::current())
+    }
+
+    pub(crate) fn new_with_handle(inner: T, handle: Handle) -> io::Result<Self> {
+        let fd = inner.as_raw_fd();
+
+        let shared = if let Some(inner) = handle.inner() {
+            inner.add_source(&mut SourceFd(&fd), ALL_INTEREST)?
+        } else {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                "failed to find event loop",
+            ));
+        };
+
+        Ok(AsyncFd {
+            handle,
+            shared,
+            inner: Some(inner),
+        })
+    }
+
+    /// Returns a shared reference to the backing object of this [`AsyncFd`].
+    #[inline]
+    pub fn get_ref(&self) -> &T {
+        self.inner.as_ref().unwrap()
+    }
+
+    /// Returns a mutable reference to the backing object of this [`AsyncFd`].
+    #[inline]
+    pub fn get_mut(&mut self) -> &mut T {
+        self.inner.as_mut().unwrap()
+    }
+
+    /// Deregisters this file descriptor and returns ownership of the backing
+    /// object.
+    pub fn into_inner(mut self) -> T {
+        self.inner.take().unwrap()
+    }
+
+    /// Polls for read readiness. This function retains the waker for the last
+    /// context that called [`poll_read_ready`]; it therefore can only be used
+    /// by a single task at a time (however, [`poll_write_ready`] retains a
+    /// second, independent waker).
+    ///
+    /// This function is intended for cases where creating and pinning a
+    /// future via [`readable`] is not feasible. Where possible, using
+    /// [`readable`] is preferred, as this supports polling from multiple
+    /// tasks at once.
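+    ///
+    /// As a sketch, a manual poll function that combines `poll_read_ready`
+    /// with a caller-supplied non-blocking read might look like the
+    /// following (the `do_read` callback is a stand-in for whatever
+    /// non-blocking operation the inner object provides):
+    ///
+    /// ```no_run
+    /// use std::io;
+    /// use std::os::unix::io::AsRawFd;
+    /// use std::task::{Context, Poll};
+    /// use tokio::io::unix::AsyncFd;
+    ///
+    /// fn poll_read<T: AsRawFd>(
+    ///     afd: &AsyncFd<T>,
+    ///     cx: &mut Context<'_>,
+    ///     buf: &mut [u8],
+    ///     do_read: impl Fn(&T, &mut [u8]) -> io::Result<usize>,
+    /// ) -> Poll<io::Result<usize>> {
+    ///     loop {
+    ///         // Pending here also stores cx's waker for the read direction.
+    ///         let mut guard = match afd.poll_read_ready(cx) {
+    ///             Poll::Ready(r) => r?,
+    ///             Poll::Pending => return Poll::Pending,
+    ///         };
+    ///         match guard.with_io(|| do_read(afd.get_ref(), buf)) {
+    ///             // Readiness was stale; the flag is now cleared, so retry.
+    ///             Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
+    ///             result => return Poll::Ready(result),
+    ///         }
+    ///     }
+    /// }
+    /// ```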
+    ///
+    /// [`poll_read_ready`]: method@Self::poll_read_ready
+    /// [`poll_write_ready`]: method@Self::poll_write_ready
+    /// [`readable`]: method@Self::readable
+    pub fn poll_read_ready<'a>(
+        &'a self,
+        cx: &mut Context<'_>,
+    ) -> Poll<io::Result<AsyncFdReadyGuard<'a, T>>> {
+        let event = ready!(self.shared.poll_readiness(cx, Direction::Read));
+
+        if !self.handle.is_alive() {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                "IO driver has terminated",
+            ))
+            .into();
+        }
+
+        Ok(AsyncFdReadyGuard {
+            async_fd: self,
+            event: Some(event),
+        })
+        .into()
+    }
+
+    /// Polls for write readiness. This function retains the waker for the
+    /// last context that called [`poll_write_ready`]; it therefore can only
+    /// be used by a single task at a time (however, [`poll_read_ready`]
+    /// retains a second, independent waker).
+    ///
+    /// This function is intended for cases where creating and pinning a
+    /// future via [`writable`] is not feasible. Where possible, using
+    /// [`writable`] is preferred, as this supports polling from multiple
+    /// tasks at once.
+    ///
+    /// [`poll_read_ready`]: method@Self::poll_read_ready
+    /// [`poll_write_ready`]: method@Self::poll_write_ready
+    /// [`writable`]: method@Self::writable
+    pub fn poll_write_ready<'a>(
+        &'a self,
+        cx: &mut Context<'_>,
+    ) -> Poll<io::Result<AsyncFdReadyGuard<'a, T>>> {
+        let event = ready!(self.shared.poll_readiness(cx, Direction::Write));
+
+        if !self.handle.is_alive() {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                "IO driver has terminated",
+            ))
+            .into();
+        }
+
+        Ok(AsyncFdReadyGuard {
+            async_fd: self,
+            event: Some(event),
+        })
+        .into()
+    }
+
+    async fn readiness(&self, interest: mio::Interest) -> io::Result<AsyncFdReadyGuard<'_, T>> {
+        let event = self.shared.readiness(interest);
+
+        if !self.handle.is_alive() {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                "IO driver has terminated",
+            ));
+        }
+
+        let event = event.await;
+        Ok(AsyncFdReadyGuard {
+            async_fd: self,
+            event: Some(event),
+        })
+    }
+
+    /// Waits for the file descriptor to become readable, returning a
+    /// [`AsyncFdReadyGuard`] that must be dropped to resume read-readiness
+    /// polling.
+    ///
+    /// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard
+    pub async fn readable(&self) -> io::Result<AsyncFdReadyGuard<'_, T>> {
+        self.readiness(mio::Interest::READABLE).await
+    }
+
+    /// Waits for the file descriptor to become writable, returning a
+    /// [`AsyncFdReadyGuard`] that must be dropped to resume write-readiness
+    /// polling.
+    ///
+    /// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard
+    pub async fn writable(&self) -> io::Result<AsyncFdReadyGuard<'_, T>> {
+        self.readiness(mio::Interest::WRITABLE).await
+    }
+}
diff --git a/tokio/src/io/driver/mod.rs b/tokio/src/io/driver/mod.rs
index cd82b26f0f8..a0d8e6f2382 100644
--- a/tokio/src/io/driver/mod.rs
+++ b/tokio/src/io/driver/mod.rs
@@ -7,8 +7,8 @@ mod scheduled_io;
 pub(crate) use scheduled_io::ScheduledIo; // pub(crate) for tests
 
 use crate::park::{Park, Unpark};
-use crate::util::bit;
 use crate::util::slab::{self, Slab};
+use crate::{loom::sync::Mutex, util::bit};
 
 use std::fmt;
 use std::io;
@@ -25,8 +25,10 @@ pub(crate) struct Driver {
     events: Option<mio::Events>,
 
     /// Primary slab handle containing the state for each resource registered
-    /// with this driver.
-    resources: Slab<ScheduledIo>,
+    /// with this driver.
+    /// During `Drop` this is moved into the `Inner` structure, so it is an
+    /// `Option` to allow it to be vacated (until `Drop`, it is always `Some`).
+    resources: Option<Slab<ScheduledIo>>,
 
     /// The system event queue
     poll: mio::Poll,
@@ -47,6 +49,14 @@ pub(crate) struct ReadyEvent {
 }
 
 pub(super) struct Inner {
+    /// Primary slab handle containing the state for each resource registered
+    /// with this driver.
+    ///
+    /// The ownership of this slab is moved into this structure during
+    /// `Driver::drop`, so that `Inner::drop` can notify all outstanding handles
+    /// without risking new ones being registered in the meantime.
+    resources: Mutex<Option<Slab<ScheduledIo>>>,
+
     /// Registers I/O resources
     registry: mio::Registry,
 
@@ -104,9 +114,10 @@ impl Driver {
         Ok(Driver {
             tick: 0,
             events: Some(mio::Events::with_capacity(1024)),
-            resources: slab,
             poll,
+            resources: Some(slab),
             inner: Arc::new(Inner {
+                resources: Mutex::new(None),
                 registry,
                 io_dispatch: allocator,
                 waker,
@@ -133,7 +144,7 @@ impl Driver {
         self.tick = self.tick.wrapping_add(1);
 
         if self.tick == COMPACT_INTERVAL {
-            self.resources.compact();
+            self.resources.as_mut().unwrap().compact()
         }
 
         let mut events = self.events.take().expect("i/o driver event store missing");
@@ -163,7 +174,9 @@ impl Driver {
     fn dispatch(&mut self, token: mio::Token, ready: Ready) {
         let addr = slab::Address::from_usize(ADDRESS.unpack(token.0));
 
-        let io = match self.resources.get(addr) {
+        let resources = self.resources.as_mut().unwrap();
+
+        let io = match resources.get(addr) {
             Some(io) => io,
             None => return,
         };
@@ -181,12 +194,22 @@ impl Driver {
 
 impl Drop for Driver {
     fn drop(&mut self) {
-        self.resources.for_each(|io| {
-            // If a task is waiting on the I/O resource, notify it. The task
-            // will then attempt to use the I/O resource and fail due to the
-            // driver being shutdown.
-            io.wake(Ready::ALL);
-        })
+        (*self.inner.resources.lock()) = self.resources.take();
+    }
+}
+
+impl Drop for Inner {
+    fn drop(&mut self) {
+        let resources = self.resources.lock().take();
+
+        if let Some(mut slab) = resources {
+            slab.for_each(|io| {
+                // If a task is waiting on the I/O resource, notify it. The task
+                // will then attempt to use the I/O resource and fail due to the
+                // driver being shutdown.
+                io.shutdown();
+            });
+        }
     }
 }
@@ -267,6 +290,12 @@ impl Handle {
     pub(super) fn inner(&self) -> Option<Arc<Inner>> {
         self.inner.upgrade()
     }
+
+    cfg_net_unix! {
+        pub(super) fn is_alive(&self) -> bool {
+            self.inner.strong_count() > 0
+        }
+    }
 }
 
 impl Unpark for Handle {
diff --git a/tokio/src/io/driver/scheduled_io.rs b/tokio/src/io/driver/scheduled_io.rs
index b1354a0551a..3aefb3766d5 100644
--- a/tokio/src/io/driver/scheduled_io.rs
+++ b/tokio/src/io/driver/scheduled_io.rs
@@ -1,4 +1,4 @@
-use super::{Direction, Ready, ReadyEvent, Tick};
+use super::{Ready, ReadyEvent, Tick};
 use crate::loom::sync::atomic::AtomicUsize;
 use crate::loom::sync::Mutex;
 use crate::util::bit;
@@ -7,6 +7,8 @@ use crate::util::slab::Entry;
 use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
 use std::task::{Context, Poll, Waker};
 
+use super::Direction;
+
 cfg_io_readiness! {
     use crate::util::linked_list::{self, LinkedList};
 
@@ -41,6 +43,9 @@ struct Waiters {
 
     /// Waker used for AsyncWrite
     writer: Option<Waker>,
+
+    /// True if this ScheduledIo has been killed due to IO driver shutdown
+    is_shutdown: bool,
 }
 
 cfg_io_readiness! {
@@ -121,6 +126,12 @@ impl ScheduledIo {
         GENERATION.unpack(self.readiness.load(Acquire))
     }
 
+    /// Invoked when the IO driver is shut down; forces this ScheduledIo into a
+    /// permanently ready state.
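+    ///
+    /// Once shut down, any readiness wait on this `ScheduledIo` completes
+    /// immediately (pending waiters are woken as if all interests were
+    /// ready); wrappers such as `AsyncFd` then check `Handle::is_alive`
+    /// before starting a new wait and report the shutdown as an IO error.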
+    pub(super) fn shutdown(&self) {
+        self.wake0(Ready::ALL, true)
+    }
+
     /// Sets the readiness on this `ScheduledIo` by invoking the given closure on
     /// the current value, returning the previous readiness value.
     ///
@@ -197,6 +208,10 @@ impl ScheduledIo {
     /// than 32 wakers to notify, if the stack array fills up, the lock is
     /// released, the array is cleared, and the iteration continues.
     pub(super) fn wake(&self, ready: Ready) {
+        self.wake0(ready, false);
+    }
+
+    fn wake0(&self, ready: Ready, shutdown: bool) {
         const NUM_WAKERS: usize = 32;
 
         let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default();
@@ -204,6 +219,8 @@ impl ScheduledIo {
 
         let mut waiters = self.waiters.lock();
 
+        waiters.is_shutdown |= shutdown;
+
         // check for AsyncRead slot
         if ready.is_readable() {
             if let Some(waker) = waiters.reader.take() {
@@ -288,7 +305,12 @@ impl ScheduledIo {
         // taking the waiters lock
         let curr = self.readiness.load(Acquire);
         let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr));
-        if ready.is_empty() {
+        if waiters.is_shutdown {
+            Poll::Ready(ReadyEvent {
+                tick: TICK.unpack(curr) as u8,
+                ready: direction.mask(),
+            })
+        } else if ready.is_empty() {
             Poll::Pending
         } else {
             Poll::Ready(ReadyEvent {
@@ -401,7 +423,12 @@ cfg_io_readiness! {
                     let mut waiters = scheduled_io.waiters.lock();
 
                     let curr = scheduled_io.readiness.load(SeqCst);
-                    let ready = Ready::from_usize(READINESS.unpack(curr));
+                    let mut ready = Ready::from_usize(READINESS.unpack(curr));
+
+                    if waiters.is_shutdown {
+                        ready = Ready::ALL;
+                    }
+
                     let ready = ready.intersection(interest);
 
                     if !ready.is_empty() {
diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs
index 9191bbcd19e..20d92233c73 100644
--- a/tokio/src/io/mod.rs
+++ b/tokio/src/io/mod.rs
@@ -207,11 +207,21 @@ pub use std::io::{Error, ErrorKind, Result, SeekFrom};
 cfg_io_driver! {
     pub(crate) mod driver;
 
+    mod registration;
+
     mod poll_evented;
+
     #[cfg(not(loom))]
     pub(crate) use poll_evented::PollEvented;
+}
 
-    mod registration;
+cfg_net_unix! {
+    mod async_fd;
+
+    pub mod unix {
+        //! Asynchronous IO structures specific to Unix-like operating systems.
+        pub use super::async_fd::{AsyncFd, AsyncFdReadyGuard};
+    }
 }
 
 cfg_io_std! {
diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs
index 66e266c3a71..690beab39c4 100644
--- a/tokio/src/lib.rs
+++ b/tokio/src/lib.rs
@@ -305,7 +305,8 @@
 //! - `rt-multi-thread`: Enables the heavier, multi-threaded, work-stealing scheduler.
 //! - `io-util`: Enables the IO based `Ext` traits.
 //! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types.
-//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`.
+//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`,
+//!   as well as (on Unix-like systems) `AsyncFd`.
 //! - `time`: Enables `tokio::time` types and allows the schedulers to enable
 //!   the built in timer.
 //! - `process`: Enables `tokio::process` types.
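Taken together, the new API is used as a readiness loop: register the non-blocking file descriptor, await readiness, attempt the IO operation through the guard, and let `with_io` clear the readiness flag whenever the operation reports `WouldBlock`. The following is a minimal sketch of that loop, not part of the patch itself; it assumes a `std::net::TcpListener` put into non-blocking mode, but any `AsRawFd` type works the same way:

```rust
use std::io;
use std::net::TcpListener;
use tokio::io::unix::AsyncFd;

// Accept connections by combining OS-level readiness with non-blocking accept().
async fn accept_loop(listener: TcpListener) -> io::Result<()> {
    // AsyncFd only reports readiness; the fd itself must be non-blocking.
    listener.set_nonblocking(true)?;
    // Registers the fd with the current tokio runtime's IO driver.
    let afd = AsyncFd::new(listener)?;

    loop {
        // Completes immediately if the tokio-side readiness flag is still set.
        let mut guard = afd.readable().await?;

        // with_io clears the readiness flag only if accept() returns
        // WouldBlock, keeping it in sync with the edge-triggered OS state.
        match guard.with_io(|| afd.get_ref().accept()) {
            Ok((_stream, addr)) => println!("connection from {}", addr),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
            Err(e) => return Err(e),
        }
    }
}
```

Note that the guard is taken fresh on every iteration: each `readable().await` re-checks the stored readiness flag, so a stale or spurious wakeup simply loops back into another wait.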
diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs
index 9525286895f..414ef90623b 100644
--- a/tokio/src/loom/std/mod.rs
+++ b/tokio/src/loom/std/mod.rs
@@ -74,7 +74,7 @@ pub(crate) mod sync {
         pub(crate) use crate::loom::std::atomic_u8::AtomicU8;
         pub(crate) use crate::loom::std::atomic_usize::AtomicUsize;
 
-        pub(crate) use std::sync::atomic::{spin_loop_hint, AtomicBool};
+        pub(crate) use std::sync::atomic::{fence, spin_loop_hint, AtomicBool, Ordering};
     }
 }
diff --git a/tokio/src/loom/std/parking_lot.rs b/tokio/src/loom/std/parking_lot.rs
index c03190feb8b..8448bed53d7 100644
--- a/tokio/src/loom/std/parking_lot.rs
+++ b/tokio/src/loom/std/parking_lot.rs
@@ -43,6 +43,11 @@ impl<T> Mutex<T> {
         self.0.try_lock()
     }
 
+    #[inline]
+    pub(crate) fn get_mut(&mut self) -> &mut T {
+        self.0.get_mut()
+    }
+
     // Note: Additional methods `is_poisoned` and `into_inner`, can be
     // provided here as needed.
 }
diff --git a/tokio/tests/io_async_fd.rs b/tokio/tests/io_async_fd.rs
new file mode 100644
index 00000000000..0303eff6612
--- /dev/null
+++ b/tokio/tests/io_async_fd.rs
@@ -0,0 +1,604 @@
+#![warn(rust_2018_idioms)]
+#![cfg(all(unix, feature = "full"))]
+
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
+use std::time::Duration;
+use std::{
+    future::Future,
+    io::{self, ErrorKind, Read, Write},
+    task::{Context, Waker},
+};
+
+use nix::errno::Errno;
+use nix::unistd::{close, read, write};
+
+use futures::{poll, FutureExt};
+
+use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard};
+use tokio_test::{assert_err, assert_pending};
+
+struct TestWaker {
+    inner: Arc<TestWakerInner>,
+    waker: Waker,
+}
+
+#[derive(Default)]
+struct TestWakerInner {
+    awoken: AtomicBool,
+}
+
+impl futures::task::ArcWake for TestWakerInner {
+    fn wake_by_ref(arc_self: &Arc<Self>) {
+        arc_self.awoken.store(true, Ordering::SeqCst);
+    }
+}
+
+impl TestWaker {
+    fn new() -> Self {
+        let inner: Arc<TestWakerInner> = Default::default();
+
+        Self {
+            inner: inner.clone(),
+            waker: futures::task::waker(inner),
+        }
+    }
+
+    fn awoken(&self) -> bool {
+        self.inner.awoken.swap(false, Ordering::SeqCst)
+    }
+
+    fn context(&self) -> Context<'_> {
+        Context::from_waker(&self.waker)
+    }
+}
+
+fn is_blocking(e: &nix::Error) -> bool {
+    Some(Errno::EAGAIN) == e.as_errno()
+}
+
+#[derive(Debug)]
+struct FileDescriptor {
+    fd: RawFd,
+}
+
+impl AsRawFd for FileDescriptor {
+    fn as_raw_fd(&self) -> RawFd {
+        self.fd
+    }
+}
+
+impl Read for &FileDescriptor {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        match read(self.fd, buf) {
+            Ok(n) => Ok(n),
+            Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()),
+            Err(e) => Err(io::Error::new(ErrorKind::Other, e)),
+        }
+    }
+}
+
+impl Read for FileDescriptor {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        (self as &Self).read(buf)
+    }
+}
+
+impl Write for &FileDescriptor {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        match write(self.fd, buf) {
+            Ok(n) => Ok(n),
+            Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()),
+            Err(e) => Err(io::Error::new(ErrorKind::Other, e)),
+        }
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+}
+
+impl Write for FileDescriptor {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        (self as &Self).write(buf)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        (self as &Self).flush()
+    }
+}
+
+impl Drop for FileDescriptor {
+    fn drop(&mut self) {
+        let _ = close(self.fd);
+    }
+}
+
+fn set_nonblocking(fd: RawFd) {
+    use nix::fcntl::{OFlag, F_GETFL, F_SETFL};
+
+    let flags = nix::fcntl::fcntl(fd, F_GETFL)
+        .expect("fcntl(F_GETFL)");
+
+    if flags < 0 {
+        panic!(
+            "bad return value from fcntl(F_GETFL): {} ({:?})",
+            flags,
+            nix::Error::last()
+        );
+    }
+
+    let flags = OFlag::from_bits_truncate(flags) | OFlag::O_NONBLOCK;
+
+    nix::fcntl::fcntl(fd, F_SETFL(flags)).expect("fcntl(F_SETFL)");
+}
+
+fn socketpair() -> (FileDescriptor, FileDescriptor) {
+    use nix::sys::socket::{self, AddressFamily, SockFlag, SockType};
+
+    let (fd_a, fd_b) = socket::socketpair(
+        AddressFamily::Unix,
+        SockType::Stream,
+        None,
+        SockFlag::empty(),
+    )
+    .expect("socketpair");
+    let fds = (FileDescriptor { fd: fd_a }, FileDescriptor { fd: fd_b });
+
+    set_nonblocking(fds.0.fd);
+    set_nonblocking(fds.1.fd);
+
+    fds
+}
+
+fn drain(mut fd: &FileDescriptor) {
+    let mut buf = [0u8; 512];
+
+    loop {
+        match fd.read(&mut buf[..]) {
+            Err(e) if e.kind() == ErrorKind::WouldBlock => break,
+            Ok(0) => panic!("unexpected EOF"),
+            Err(e) => panic!("unexpected error: {:?}", e),
+            Ok(_) => continue,
+        }
+    }
+}
+
+#[tokio::test]
+async fn initially_writable() {
+    let (a, b) = socketpair();
+
+    let afd_a = AsyncFd::new(a).unwrap();
+    let afd_b = AsyncFd::new(b).unwrap();
+
+    afd_a.writable().await.unwrap().clear_ready();
+    afd_b.writable().await.unwrap().clear_ready();
+
+    futures::select_biased! {
+        _ = tokio::time::sleep(Duration::from_millis(10)).fuse() => {},
+        _ = afd_a.readable().fuse() => panic!("Unexpected readable state"),
+        _ = afd_b.readable().fuse() => panic!("Unexpected readable state"),
+    }
+}
+
+#[tokio::test]
+async fn reset_readable() {
+    let (a, mut b) = socketpair();
+
+    let afd_a = AsyncFd::new(a).unwrap();
+
+    let readable = afd_a.readable();
+    tokio::pin!(readable);
+
+    tokio::select! {
+        _ = readable.as_mut() => panic!(),
+        _ = tokio::time::sleep(Duration::from_millis(10)) => {}
+    }
+
+    b.write_all(b"0").unwrap();
+
+    let mut guard = readable.await.unwrap();
+
+    guard.with_io(|| afd_a.get_ref().read(&mut [0])).unwrap();
+
+    // `a` is not readable, but the reactor still thinks it is
+    // (because we have not observed a WouldBlock yet)
+    afd_a.readable().await.unwrap().retain_ready();
+
+    // Explicitly clear the ready state
+    guard.clear_ready();
+
+    let readable = afd_a.readable();
+    tokio::pin!(readable);
+
+    tokio::select! {
+        _ = readable.as_mut() => panic!(),
+        _ = tokio::time::sleep(Duration::from_millis(10)) => {}
+    }
+
+    b.write_all(b"0").unwrap();
+
+    // We can observe the new readable event
+    afd_a.readable().await.unwrap().clear_ready();
+}
+
+#[tokio::test]
+async fn reset_writable() {
+    let (a, b) = socketpair();
+
+    let afd_a = AsyncFd::new(a).unwrap();
+
+    let mut guard = afd_a.writable().await.unwrap();
+
+    // Write until we get a WouldBlock. This also clears the ready state.
+    loop {
+        if let Err(e) = guard.with_io(|| afd_a.get_ref().write(&[0; 512][..])) {
+            assert_eq!(ErrorKind::WouldBlock, e.kind());
+            break;
+        }
+    }
+
+    // Writable state should be cleared now.
+    let writable = afd_a.writable();
+    tokio::pin!(writable);
+
+    tokio::select! {
+        _ = writable.as_mut() => panic!(),
+        _ = tokio::time::sleep(Duration::from_millis(10)) => {}
+    }
+
+    // Read from the other side; we should become writable now.
+    drain(&b);
+
+    let _ = writable.await.unwrap();
+}
+
+#[derive(Debug)]
+struct ArcFd<T>(Arc<T>);
+impl<T: AsRawFd> AsRawFd for ArcFd<T> {
+    fn as_raw_fd(&self) -> RawFd {
+        self.0.as_raw_fd()
+    }
+}
+
+#[tokio::test]
+async fn drop_closes() {
+    let (a, mut b) = socketpair();
+
+    let afd_a = AsyncFd::new(a).unwrap();
+
+    assert_eq!(
+        ErrorKind::WouldBlock,
+        b.read(&mut [0]).err().unwrap().kind()
+    );
+
+    std::mem::drop(afd_a);
+
+    assert_eq!(0, b.read(&mut [0]).unwrap());
+
+    // into_inner does not close the fd
+
+    let (a, mut b) = socketpair();
+    let afd_a = AsyncFd::new(a).unwrap();
+    let _a: FileDescriptor = afd_a.into_inner();
+
+    assert_eq!(
+        ErrorKind::WouldBlock,
+        b.read(&mut [0]).err().unwrap().kind()
+    );
+
+    // Drop closure behavior is delegated to the inner object
+    let (a, mut b) = socketpair();
+    let arc_fd = Arc::new(a);
+    let afd_a = AsyncFd::new(ArcFd(arc_fd.clone())).unwrap();
+    std::mem::drop(afd_a);
+
+    assert_eq!(
+        ErrorKind::WouldBlock,
+        b.read(&mut [0]).err().unwrap().kind()
+    );
+
+    std::mem::drop(arc_fd); // suppress unnecessary clone clippy warning
+}
+
+#[tokio::test]
+async fn with_poll() {
+    use std::task::Poll;
+
+    let (a, mut b) = socketpair();
+
+    b.write_all(b"0").unwrap();
+
+    let afd_a = AsyncFd::new(a).unwrap();
+
+    let mut guard = afd_a.readable().await.unwrap();
+
+    afd_a.get_ref().read_exact(&mut [0]).unwrap();
+
+    // Should not clear the readable state
+    let _ = guard.with_poll(|| Poll::Ready(()));
+
+    // Still readable...
+    let _ = afd_a.readable().await.unwrap();
+
+    // Should clear the readable state
+    let _ = guard.with_poll(|| Poll::Pending::<()>);
+
+    // Assert not readable
+    let readable = afd_a.readable();
+    tokio::pin!(readable);
+
+    tokio::select! {
+        _ = readable.as_mut() => panic!(),
+        _ = tokio::time::sleep(Duration::from_millis(10)) => {}
+    }
+
+    // Write something down b again and make sure we're reawoken
+    b.write_all(b"0").unwrap();
+    let _ = readable.await.unwrap();
+}
+
+#[tokio::test]
+async fn multiple_waiters() {
+    let (a, mut b) = socketpair();
+    let afd_a = Arc::new(AsyncFd::new(a).unwrap());
+
+    let barrier = Arc::new(tokio::sync::Barrier::new(11));
+
+    let mut tasks = Vec::new();
+    for _ in 0..10 {
+        let afd_a = afd_a.clone();
+        let barrier = barrier.clone();
+
+        let f = async move {
+            let notify_barrier = async {
+                barrier.wait().await;
+                futures::future::pending::<()>().await;
+            };
+
+            futures::select_biased! {
+                guard = afd_a.readable().fuse() => {
+                    tokio::task::yield_now().await;
+                    guard.unwrap().clear_ready()
+                },
+                _ = notify_barrier.fuse() => unreachable!(),
+            }
+
+            std::mem::drop(afd_a);
+        };
+
+        tasks.push(tokio::spawn(f));
+    }
+
+    let mut all_tasks = futures::future::try_join_all(tasks);
+
+    tokio::select! {
+        r = std::pin::Pin::new(&mut all_tasks) => {
+            r.unwrap(); // propagate panic
+            panic!("Tasks exited unexpectedly")
+        },
+        _ = barrier.wait() => {}
+    };
+
+    b.write_all(b"0").unwrap();
+
+    all_tasks.await.unwrap();
+}
+
+#[tokio::test]
+async fn poll_fns() {
+    let (a, b) = socketpair();
+    let afd_a = Arc::new(AsyncFd::new(a).unwrap());
+    let afd_b = Arc::new(AsyncFd::new(b).unwrap());
+
+    // Fill up the write side of A
+    while afd_a.get_ref().write(&[0; 512]).is_ok() {}
+
+    let waker = TestWaker::new();
+
+    assert_pending!(afd_a.as_ref().poll_read_ready(&mut waker.context()));
+
+    let afd_a_2 = afd_a.clone();
+    let r_barrier = Arc::new(tokio::sync::Barrier::new(2));
+    let barrier_clone = r_barrier.clone();
+
+    let read_fut = tokio::spawn(async move {
+        // Move waker onto this task first
+        assert_pending!(poll!(futures::future::poll_fn(|cx| afd_a_2
+            .as_ref()
+            .poll_read_ready(cx))));
+        barrier_clone.wait().await;
+
+        let _ = futures::future::poll_fn(|cx| afd_a_2.as_ref().poll_read_ready(cx)).await;
+    });
+
+    let afd_a_2 = afd_a.clone();
+    let w_barrier = Arc::new(tokio::sync::Barrier::new(2));
+    let barrier_clone = w_barrier.clone();
+
+    let mut write_fut = tokio::spawn(async move {
+        // Move waker onto this task first
+        assert_pending!(poll!(futures::future::poll_fn(|cx| afd_a_2
+            .as_ref()
+            .poll_write_ready(cx))));
+        barrier_clone.wait().await;
+
+        let _ = futures::future::poll_fn(|cx| afd_a_2.as_ref().poll_write_ready(cx)).await;
+    });
+
+    r_barrier.wait().await;
+    w_barrier.wait().await;
+
+    let readable = afd_a.readable();
+    tokio::pin!(readable);
+
+    tokio::select! {
+        _ = &mut readable => unreachable!(),
+        _ = tokio::task::yield_now() => {}
+    }
+
+    // Make A readable. We expect that 'readable' and 'read_fut' will both complete quickly
+    afd_b.get_ref().write_all(b"0").unwrap();
+
+    let _ = tokio::join!(readable, read_fut);
+
+    // Our original waker should _not_ be awoken (poll_read_ready retains only the last context)
+    assert!(!waker.awoken());
+
+    // The writable side should not be awoken
+    tokio::select! {
+        _ = &mut write_fut => unreachable!(),
+        _ = tokio::time::sleep(Duration::from_millis(5)) => {}
+    }
+
+    // Make it writable now
+    drain(afd_b.get_ref());
+
+    // Now we should be writable (i.e., the waker for poll_write_ready should
+    // still be registered after we woke the read side)
+    let _ = write_fut.await;
+}
+
+fn assert_pending<T: std::fmt::Debug, F: Future<Output = T>>(f: F) -> std::pin::Pin<Box<F>> {
+    let mut pinned = Box::pin(f);
+
+    assert_pending!(pinned
+        .as_mut()
+        .poll(&mut Context::from_waker(futures::task::noop_waker_ref())));
+
+    pinned
+}
+
+fn rt() -> tokio::runtime::Runtime {
+    tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .unwrap()
+}
+
+#[test]
+fn driver_shutdown_wakes_currently_pending() {
+    let rt = rt();
+
+    let (a, _b) = socketpair();
+    let afd_a = {
+        let _enter = rt.enter();
+        AsyncFd::new(a).unwrap()
+    };
+
+    let readable = assert_pending(afd_a.readable());
+
+    std::mem::drop(rt);
+
+    // Being awoken by a rt drop does not return an error, currently...
+    let _ = futures::executor::block_on(readable).unwrap();
+
+    // However, attempting to initiate a readiness wait when the rt is dropped is an error
+    assert_err!(futures::executor::block_on(afd_a.readable()));
+}
+
+#[test]
+fn driver_shutdown_wakes_future_pending() {
+    let rt = rt();
+
+    let (a, _b) = socketpair();
+    let afd_a = {
+        let _enter = rt.enter();
+        AsyncFd::new(a).unwrap()
+    };
+
+    std::mem::drop(rt);
+
+    assert_err!(futures::executor::block_on(afd_a.readable()));
+}
+
+#[test]
+fn driver_shutdown_wakes_pending_race() {
+    // TODO: make this a loom test
+    for _ in 0..100 {
+        let rt = rt();
+
+        let (a, _b) = socketpair();
+        let afd_a = {
+            let _enter = rt.enter();
+            AsyncFd::new(a).unwrap()
+        };
+
+        let _ = std::thread::spawn(move || std::mem::drop(rt));
+
+        // This may or may not return an error (but will be awoken)
+        let _ = futures::executor::block_on(afd_a.readable());
+
+        // However, retrying will always return an error
+        assert_err!(futures::executor::block_on(afd_a.readable()));
+    }
+}
+
+async fn poll_readable<T: AsRawFd>(fd: &AsyncFd<T>) -> std::io::Result<AsyncFdReadyGuard<'_, T>> {
+    futures::future::poll_fn(|cx| fd.poll_read_ready(cx)).await
+}
+
+async fn poll_writable<T: AsRawFd>(fd: &AsyncFd<T>) -> std::io::Result<AsyncFdReadyGuard<'_, T>> {
+    futures::future::poll_fn(|cx| fd.poll_write_ready(cx)).await
+}
+
+#[test]
+fn driver_shutdown_wakes_currently_pending_polls() {
+    let rt = rt();
+
+    let (a, _b) = socketpair();
+    let afd_a = {
+        let _enter = rt.enter();
+        AsyncFd::new(a).unwrap()
+    };
+
+    while afd_a.get_ref().write(&[0; 512]).is_ok() {} // make not writable
+
+    let readable = assert_pending(poll_readable(&afd_a));
+    let writable = assert_pending(poll_writable(&afd_a));
+
+    std::mem::drop(rt);
+
+    // Attempting to poll readiness when the rt is dropped is an error
+    assert_err!(futures::executor::block_on(readable));
+    assert_err!(futures::executor::block_on(writable));
+}
+
+#[test]
+fn driver_shutdown_wakes_poll() {
+    let rt = rt();
+
+    let (a, _b) = socketpair();
+    let afd_a = {
+        let _enter = rt.enter();
+        AsyncFd::new(a).unwrap()
+    };
+
+    std::mem::drop(rt);
+
+    assert_err!(futures::executor::block_on(poll_readable(&afd_a)));
+    assert_err!(futures::executor::block_on(poll_writable(&afd_a)));
+}
+
+#[test]
+fn driver_shutdown_wakes_poll_race() {
+    // TODO: make this a loom test
+    for _ in 0..100 {
+        let rt = rt();
+
+        let (a, _b) = socketpair();
+        let afd_a = {
+            let _enter = rt.enter();
+            AsyncFd::new(a).unwrap()
+        };
+
+        while afd_a.get_ref().write(&[0; 512]).is_ok() {} // make not writable
+
+        let _ = std::thread::spawn(move || std::mem::drop(rt));
+
+        // The poll variants will always return an error in this case
+        assert_err!(futures::executor::block_on(poll_readable(&afd_a)));
+        assert_err!(futures::executor::block_on(poll_writable(&afd_a)));
+    }
+}
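The shutdown tests above exercise the driver teardown paths one case at a time. As an end-to-end companion, a test along the following lines would exercise the full wait/consume/re-arm cycle that the `AsyncFd` docs describe; it is a sketch only (the name `readiness_round_trip` is hypothetical), reusing the `socketpair` and `FileDescriptor` helpers defined in this file:

```rust
#[tokio::test]
async fn readiness_round_trip() {
    let (a, mut b) = socketpair();
    let afd_a = AsyncFd::new(a).unwrap();

    b.write_all(b"ping").unwrap();

    // Wait for the edge-triggered readable notification.
    let mut guard = afd_a.readable().await.unwrap();

    // Consume until WouldBlock; with_io then clears the tokio-side readiness
    // flag, so the next readable() waits for a fresh OS notification.
    let mut buf = [0u8; 16];
    loop {
        match guard.with_io(|| afd_a.get_ref().read(&mut buf)) {
            Ok(_) => continue,
            Err(e) if e.kind() == ErrorKind::WouldBlock => break,
            Err(e) => panic!("unexpected error: {:?}", e),
        }
    }
}
```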