From 4889d97d0141f3c8bc8347bc75d457ad9ca8f01e Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Tue, 6 Oct 2020 15:46:14 -0700 Subject: [PATCH] net: use &self with TcpListener::accept Uses the infrastructure added by #2828 to enable switching `TcpListener::accept` to use `&self`. This also switches `poll_accept` to use `&self`. While doing introduces a hazard, `poll_*` style functions are considered low-level. Most users will use the `async fn` variants which are more misuse-resistant. TcpListener::incoming() is temporarily removed as it has the same problem as `TcpSocket::by_ref()` and will be implemented later. --- tokio/src/lib.rs | 2 +- tokio/src/net/tcp/incoming.rs | 42 -------------------- tokio/src/net/tcp/listener.rs | 75 +++++++++-------------------------- tokio/src/net/tcp/mod.rs | 4 -- tokio/src/runtime/mod.rs | 4 +- tokio/src/task/spawn.rs | 2 +- tokio/tests/buffered.rs | 2 +- tokio/tests/io_driver.rs | 2 +- tokio/tests/io_driver_drop.rs | 4 +- tokio/tests/rt_common.rs | 12 +++--- tokio/tests/rt_threaded.rs | 2 +- tokio/tests/tcp_accept.rs | 28 +++++-------- tokio/tests/tcp_connect.rs | 16 ++++---- tokio/tests/tcp_echo.rs | 2 +- tokio/tests/tcp_into_split.rs | 2 +- tokio/tests/tcp_shutdown.rs | 2 +- 16 files changed, 56 insertions(+), 145 deletions(-) delete mode 100644 tokio/src/net/tcp/incoming.rs diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 1b0dad5d667..948ac888a56 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -306,7 +306,7 @@ //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { -//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?; +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; //! //! loop { //! let (mut socket, _) = listener.accept().await?; diff --git a/tokio/src/net/tcp/incoming.rs b/tokio/src/net/tcp/incoming.rs deleted file mode 100644 index 062be1e9cf9..00000000000 --- a/tokio/src/net/tcp/incoming.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::net::tcp::{TcpListener, TcpStream}; - -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Stream returned by the `TcpListener::incoming` function representing the -/// stream of sockets received from a listener. -#[must_use = "streams do nothing unless polled"] -#[derive(Debug)] -pub struct Incoming<'a> { - inner: &'a mut TcpListener, -} - -impl Incoming<'_> { - pub(crate) fn new(listener: &mut TcpListener) -> Incoming<'_> { - Incoming { inner: listener } - } - - /// Attempts to poll `TcpStream` by polling inner `TcpListener` to accept - /// connection. - /// - /// If `TcpListener` isn't ready yet, `Poll::Pending` is returned and - /// current task will be notified by a waker. - pub fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let (socket, _) = ready!(self.inner.poll_accept(cx))?; - Poll::Ready(Ok(socket)) - } -} - -#[cfg(feature = "stream")] -impl crate::stream::Stream for Incoming<'_> { - type Item = io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let (socket, _) = ready!(self.inner.poll_accept(cx))?; - Poll::Ready(Some(Ok(socket))) - } -} diff --git a/tokio/src/net/tcp/listener.rs b/tokio/src/net/tcp/listener.rs index 0ac03632c83..77f71fbffa7 100644 --- a/tokio/src/net/tcp/listener.rs +++ b/tokio/src/net/tcp/listener.rs @@ -1,6 +1,5 @@ -use crate::future::poll_fn; use crate::io::PollEvented; -use crate::net::tcp::{Incoming, TcpStream}; +use crate::net::tcp::TcpStream; use crate::net::{to_socket_addrs, ToSocketAddrs}; use std::convert::TryFrom; @@ -40,7 +39,7 @@ cfg_tcp! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; /// /// loop { /// let (socket, _) = listener.accept().await?; @@ -171,7 +170,7 @@ impl TcpListener { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; /// /// match listener.accept().await { /// Ok((_socket, addr)) => println!("new client: {:?}", addr), @@ -181,18 +180,25 @@ impl TcpListener { /// Ok(()) /// } /// ``` - pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> { - poll_fn(|cx| self.poll_accept(cx)).await + pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let (mio, addr) = self + .io + .async_io(mio::Interest::READABLE, |sock| sock.accept()) + .await?; + + let stream = TcpStream::new(mio)?; + Ok((stream, addr)) } /// Polls to accept a new incoming connection to this listener. /// - /// If there is no connection to accept, `Poll::Pending` is returned and - /// the current task will be notified by a waker. - pub fn poll_accept( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { + /// If there is no connection to accept, `Poll::Pending` is returned and the + /// current task will be notified by a waker. + /// + /// When ready, the most recent task that called `poll_accept` is notified. + /// The caller is responsble to ensure that `poll_accept` is called from a + /// single task. Failing to do this could result in tasks hanging. + pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll> { loop { let ev = ready!(self.io.poll_read_ready(cx))?; @@ -293,46 +299,6 @@ impl TcpListener { self.io.get_ref().local_addr() } - /// Returns a stream over the connections being received on this listener. - /// - /// Note that `TcpListener` also directly implements `Stream`. - /// - /// The returned stream will never return `None` and will also not yield the - /// peer's `SocketAddr` structure. Iterating over it is equivalent to - /// calling accept in a loop. - /// - /// # Errors - /// - /// Note that accepting a connection can lead to various errors and not all - /// of them are necessarily fatal ‒ for example having too many open file - /// descriptors or the other side closing the connection while it waits in - /// an accept queue. These would terminate the stream if not handled in any - /// way. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::{net::TcpListener, stream::StreamExt}; - /// - /// #[tokio::main] - /// async fn main() { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); - /// let mut incoming = listener.incoming(); - /// - /// while let Some(stream) = incoming.next().await { - /// match stream { - /// Ok(stream) => { - /// println!("new client!"); - /// } - /// Err(e) => { /* connection failed */ } - /// } - /// } - /// } - /// ``` - pub fn incoming(&mut self) -> Incoming<'_> { - Incoming::new(self) - } - /// Gets the value of the `IP_TTL` option for this socket. /// /// For more information about this option, see [`set_ttl`]. @@ -390,10 +356,7 @@ impl TcpListener { impl crate::stream::Stream for TcpListener { type Item = io::Result; - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let (socket, _) = ready!(self.poll_accept(cx))?; Poll::Ready(Some(Ok(socket))) } diff --git a/tokio/src/net/tcp/mod.rs b/tokio/src/net/tcp/mod.rs index 7ad36eb0b11..516770e18d2 100644 --- a/tokio/src/net/tcp/mod.rs +++ b/tokio/src/net/tcp/mod.rs @@ -1,10 +1,6 @@ //! TCP utility types pub(crate) mod listener; -pub(crate) use listener::TcpListener; - -mod incoming; -pub use incoming::Incoming; mod split; pub use split::{ReadHalf, WriteHalf}; diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index a6a739bec36..22109f7d155 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -25,7 +25,7 @@ //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { -//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?; +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; //! //! loop { //! let (mut socket, _) = listener.accept().await?; @@ -73,7 +73,7 @@ //! //! // Spawn the root task //! rt.block_on(async { -//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?; +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; //! //! loop { //! let (mut socket, _) = listener.accept().await?; diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index 280e90ead04..d7aca5723cb 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -37,7 +37,7 @@ doc_rt_core! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; /// /// loop { /// let (socket, _) = listener.accept().await?; diff --git a/tokio/tests/buffered.rs b/tokio/tests/buffered.rs index 595f855a0f7..97ba00cd1bf 100644 --- a/tokio/tests/buffered.rs +++ b/tokio/tests/buffered.rs @@ -13,7 +13,7 @@ use std::thread; async fn echo_server() { const N: usize = 1024; - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let msg = "foo bar baz"; diff --git a/tokio/tests/io_driver.rs b/tokio/tests/io_driver.rs index d4f4f8d48cf..01be36599a6 100644 --- a/tokio/tests/io_driver.rs +++ b/tokio/tests/io_driver.rs @@ -56,7 +56,7 @@ fn test_drop_on_notify() { // Define a task that just drains the listener let task = Arc::new(Task::new(async move { // Create a listener - let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Send the address let addr = listener.local_addr().unwrap(); diff --git a/tokio/tests/io_driver_drop.rs b/tokio/tests/io_driver_drop.rs index 0a5ce62513b..2ee02a4276b 100644 --- a/tokio/tests/io_driver_drop.rs +++ b/tokio/tests/io_driver_drop.rs @@ -9,7 +9,7 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task}; fn tcp_doesnt_block() { let rt = rt(); - let mut listener = rt.enter(|| { + let listener = rt.enter(|| { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); TcpListener::from_std(listener).unwrap() }); @@ -27,7 +27,7 @@ fn tcp_doesnt_block() { fn drop_wakes() { let rt = rt(); - let mut listener = rt.enter(|| { + let listener = rt.enter(|| { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); TcpListener::from_std(listener).unwrap() }); diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index 3e95c2aa4ce..93d6a44e630 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -471,7 +471,7 @@ rt_test! { rt.block_on(async move { let (tx, rx) = oneshot::channel(); - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { @@ -539,7 +539,7 @@ rt_test! { let rt = rt(); rt.block_on(async move { - let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(listener.local_addr()); let peer = tokio::task::spawn_blocking(move || { @@ -634,7 +634,7 @@ rt_test! { // Do some I/O work rt.block_on(async { - let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(listener.local_addr()); let srv = tokio::spawn(async move { @@ -912,7 +912,7 @@ rt_test! { } async fn client_server(tx: mpsc::Sender<()>) { - let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Get the assigned address let addr = assert_ok!(server.local_addr()); @@ -943,7 +943,7 @@ rt_test! { local.block_on(&rt, async move { let (tx, rx) = oneshot::channel(); - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); task::spawn_local(async move { @@ -970,7 +970,7 @@ rt_test! { } async fn client_server_local(tx: mpsc::Sender<()>) { - let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Get the assigned address let addr = assert_ok!(server.local_addr()); diff --git a/tokio/tests/rt_threaded.rs b/tokio/tests/rt_threaded.rs index 2c7cfb80c1b..1ac6ed32428 100644 --- a/tokio/tests/rt_threaded.rs +++ b/tokio/tests/rt_threaded.rs @@ -139,7 +139,7 @@ fn spawn_shutdown() { } async fn client_server(tx: mpsc::Sender<()>) { - let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await); // Get the assigned address let addr = assert_ok!(server.local_addr()); diff --git a/tokio/tests/tcp_accept.rs b/tokio/tests/tcp_accept.rs index 9f5b441468d..fd499c3523e 100644 --- a/tokio/tests/tcp_accept.rs +++ b/tokio/tests/tcp_accept.rs @@ -5,6 +5,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, oneshot}; use tokio_test::assert_ok; +use std::io; use std::net::{IpAddr, SocketAddr}; macro_rules! test_accept { @@ -12,7 +13,7 @@ macro_rules! test_accept { $( #[tokio::test] async fn $ident() { - let mut listener = assert_ok!(TcpListener::bind($target).await); + let listener = assert_ok!(TcpListener::bind($target).await); let addr = listener.local_addr().unwrap(); let (tx, rx) = oneshot::channel(); @@ -39,7 +40,6 @@ test_accept! { (ip_port_tuple, ("127.0.0.1".parse::().unwrap(), 0)), } -use pin_project_lite::pin_project; use std::pin::Pin; use std::sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -48,23 +48,17 @@ use std::sync::{ use std::task::{Context, Poll}; use tokio::stream::{Stream, StreamExt}; -pin_project! { - struct TrackPolls { - npolls: Arc, - #[pin] - s: S, - } +struct TrackPolls<'a> { + npolls: Arc, + listener: &'a mut TcpListener, } -impl Stream for TrackPolls -where - S: Stream, -{ - type Item = S::Item; +impl<'a> Stream for TrackPolls<'a> { + type Item = io::Result<(TcpStream, SocketAddr)>; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - this.npolls.fetch_add(1, SeqCst); - this.s.poll_next(cx) + self.npolls.fetch_add(1, SeqCst); + self.listener.poll_accept(cx).map(Some) } } @@ -79,7 +73,7 @@ async fn no_extra_poll() { tokio::spawn(async move { let mut incoming = TrackPolls { npolls: Arc::new(AtomicUsize::new(0)), - s: listener.incoming(), + listener: &mut listener, }; assert_ok!(tx.send(Arc::clone(&incoming.npolls))); while incoming.next().await.is_some() { diff --git a/tokio/tests/tcp_connect.rs b/tokio/tests/tcp_connect.rs index de1cead829e..44942c4e979 100644 --- a/tokio/tests/tcp_connect.rs +++ b/tokio/tests/tcp_connect.rs @@ -9,7 +9,7 @@ use futures::join; #[tokio::test] async fn connect_v4() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); assert!(addr.is_ipv4()); @@ -36,7 +36,7 @@ async fn connect_v4() { #[tokio::test] async fn connect_v6() { - let mut srv = assert_ok!(TcpListener::bind("[::1]:0").await); + let srv = assert_ok!(TcpListener::bind("[::1]:0").await); let addr = assert_ok!(srv.local_addr()); assert!(addr.is_ipv6()); @@ -63,7 +63,7 @@ async fn connect_v6() { #[tokio::test] async fn connect_addr_ip_string() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = format!("127.0.0.1:{}", addr.port()); @@ -80,7 +80,7 @@ async fn connect_addr_ip_string() { #[tokio::test] async fn connect_addr_ip_str_slice() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = format!("127.0.0.1:{}", addr.port()); @@ -97,7 +97,7 @@ async fn connect_addr_ip_str_slice() { #[tokio::test] async fn connect_addr_host_string() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = format!("localhost:{}", addr.port()); @@ -114,7 +114,7 @@ async fn connect_addr_host_string() { #[tokio::test] async fn connect_addr_ip_port_tuple() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = (addr.ip(), addr.port()); @@ -131,7 +131,7 @@ async fn connect_addr_ip_port_tuple() { #[tokio::test] async fn connect_addr_ip_str_port_tuple() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = ("127.0.0.1", addr.port()); @@ -148,7 +148,7 @@ async fn connect_addr_ip_str_port_tuple() { #[tokio::test] async fn connect_addr_host_str_port_tuple() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let addr = ("localhost", addr.port()); diff --git a/tokio/tests/tcp_echo.rs b/tokio/tests/tcp_echo.rs index 1feba63ee73..d9cb456ff6b 100644 --- a/tokio/tests/tcp_echo.rs +++ b/tokio/tests/tcp_echo.rs @@ -12,7 +12,7 @@ async fn echo_server() { let (tx, rx) = oneshot::channel(); - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); let msg = "foo bar baz"; diff --git a/tokio/tests/tcp_into_split.rs b/tokio/tests/tcp_into_split.rs index 86ed461923d..b4bb2eeb99c 100644 --- a/tokio/tests/tcp_into_split.rs +++ b/tokio/tests/tcp_into_split.rs @@ -13,7 +13,7 @@ use tokio::try_join; async fn split() -> Result<()> { const MSG: &[u8] = b"split"; - let mut listener = TcpListener::bind("127.0.0.1:0").await?; + let listener = TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; let (stream1, (mut stream2, _)) = try_join! { diff --git a/tokio/tests/tcp_shutdown.rs b/tokio/tests/tcp_shutdown.rs index bd43e143b8d..615855f1b5f 100644 --- a/tokio/tests/tcp_shutdown.rs +++ b/tokio/tests/tcp_shutdown.rs @@ -8,7 +8,7 @@ use tokio_test::assert_ok; #[tokio::test] async fn shutdown() { - let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); tokio::spawn(async move {