diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs index 3786a4002db..1099fe4993c 100644 --- a/tokio-util/src/lib.rs +++ b/tokio-util/src/lib.rs @@ -30,6 +30,7 @@ cfg_codec! { cfg_net! { pub mod udp; + pub mod net; } cfg_compat! { diff --git a/tokio-util/src/net/mod.rs b/tokio-util/src/net/mod.rs new file mode 100644 index 00000000000..4817e10d0f3 --- /dev/null +++ b/tokio-util/src/net/mod.rs @@ -0,0 +1,97 @@ +//! TCP/UDP/Unix helpers for tokio. + +use crate::either::Either; +use std::future::Future; +use std::io::Result; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[cfg(unix)] +pub mod unix; + +/// A trait for a listener: `TcpListener` and `UnixListener`. +pub trait Listener { + /// The stream's type of this listener. + type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite; + /// The socket address type of this listener. + type Addr; + + /// Polls to accept a new incoming connection to this listener. + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll>; + + /// Accepts a new incoming connection from this listener. + fn accept(&mut self) -> ListenerAcceptFut<'_, Self> + where + Self: Sized, + { + ListenerAcceptFut { listener: self } + } + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> Result; +} + +impl Listener for tokio::net::TcpListener { + type Io = tokio::net::TcpStream; + type Addr = std::net::SocketAddr; + + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll> { + Self::poll_accept(self, cx) + } + + fn local_addr(&self) -> Result { + self.local_addr().map(Into::into) + } +} + +/// Future for accepting a new connection from a listener. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ListenerAcceptFut<'a, L> { + listener: &'a mut L, +} + +impl<'a, L> Future for ListenerAcceptFut<'a, L> +where + L: Listener, +{ + type Output = Result<(L::Io, L::Addr)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.listener.poll_accept(cx) + } +} + +impl Either +where + L: Listener, + R: Listener, +{ + /// Accepts a new incoming connection from this listener. + pub async fn accept(&mut self) -> Result> { + match self { + Either::Left(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(Either::Left((stream, addr))) + } + Either::Right(listener) => { + let (stream, addr) = listener.accept().await?; + Ok(Either::Right((stream, addr))) + } + } + } + + /// Returns the local address that this listener is bound to. + pub fn local_addr(&self) -> Result> { + match self { + Either::Left(listener) => { + let addr = listener.local_addr()?; + Ok(Either::Left(addr)) + } + Either::Right(listener) => { + let addr = listener.local_addr()?; + Ok(Either::Right(addr)) + } + } + } +} diff --git a/tokio-util/src/net/unix/mod.rs b/tokio-util/src/net/unix/mod.rs new file mode 100644 index 00000000000..0b522c90a34 --- /dev/null +++ b/tokio-util/src/net/unix/mod.rs @@ -0,0 +1,18 @@ +//! Unix domain socket helpers. + +use super::Listener; +use std::io::Result; +use std::task::{Context, Poll}; + +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll> { + Self::poll_accept(self, cx) + } + + fn local_addr(&self) -> Result { + self.local_addr().map(Into::into) + } +}