Skip to content

Commit

Permalink
net: use &self with TcpListener::accept
Browse files Browse the repository at this point in the history
Uses the infrastructure added by #2828 to enable switching
`TcpListener::accept` to use `&self`.

This also switches `poll_accept` to use `&self`. While doing introduces
a hazard, `poll_*` style functions are considered low-level. Most users
will use the `async fn` variants which are more misuse-resistant.

TcpListener::incoming() is temporarily removed as it has the same
problem as `TcpSocket::by_ref()` and will be implemented later.
  • Loading branch information
carllerche committed Oct 6, 2020
1 parent fcdf934 commit 4889d97
Show file tree
Hide file tree
Showing 16 changed files with 56 additions and 145 deletions.
2 changes: 1 addition & 1 deletion tokio/src/lib.rs
Expand Up @@ -306,7 +306,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
Expand Down
42 changes: 0 additions & 42 deletions tokio/src/net/tcp/incoming.rs

This file was deleted.

75 changes: 19 additions & 56 deletions tokio/src/net/tcp/listener.rs
@@ -1,6 +1,5 @@
use crate::future::poll_fn;
use crate::io::PollEvented;
use crate::net::tcp::{Incoming, TcpStream};
use crate::net::tcp::TcpStream;
use crate::net::{to_socket_addrs, ToSocketAddrs};

use std::convert::TryFrom;
Expand Down Expand Up @@ -40,7 +39,7 @@ cfg_tcp! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
Expand Down Expand Up @@ -171,7 +170,7 @@ impl TcpListener {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// match listener.accept().await {
/// Ok((_socket, addr)) => println!("new client: {:?}", addr),
Expand All @@ -181,18 +180,25 @@ impl TcpListener {
/// Ok(())
/// }
/// ```
pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
poll_fn(|cx| self.poll_accept(cx)).await
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
let (mio, addr) = self
.io
.async_io(mio::Interest::READABLE, |sock| sock.accept())
.await?;

let stream = TcpStream::new(mio)?;
Ok((stream, addr))
}

/// Polls to accept a new incoming connection to this listener.
///
/// If there is no connection to accept, `Poll::Pending` is returned and
/// the current task will be notified by a waker.
pub fn poll_accept(
&mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
/// If there is no connection to accept, `Poll::Pending` is returned and the
/// current task will be notified by a waker.
///
/// When ready, the most recent task that called `poll_accept` is notified.
/// The caller is responsble to ensure that `poll_accept` is called from a
/// single task. Failing to do this could result in tasks hanging.
pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

Expand Down Expand Up @@ -293,46 +299,6 @@ impl TcpListener {
self.io.get_ref().local_addr()
}

/// Returns a stream over the connections being received on this listener.
///
/// Note that `TcpListener` also directly implements `Stream`.
///
/// The returned stream will never return `None` and will also not yield the
/// peer's `SocketAddr` structure. Iterating over it is equivalent to
/// calling accept in a loop.
///
/// # Errors
///
/// Note that accepting a connection can lead to various errors and not all
/// of them are necessarily fatal ‒ for example having too many open file
/// descriptors or the other side closing the connection while it waits in
/// an accept queue. These would terminate the stream if not handled in any
/// way.
///
/// # Examples
///
/// ```no_run
/// use tokio::{net::TcpListener, stream::StreamExt};
///
/// #[tokio::main]
/// async fn main() {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
/// let mut incoming = listener.incoming();
///
/// while let Some(stream) = incoming.next().await {
/// match stream {
/// Ok(stream) => {
/// println!("new client!");
/// }
/// Err(e) => { /* connection failed */ }
/// }
/// }
/// }
/// ```
pub fn incoming(&mut self) -> Incoming<'_> {
Incoming::new(self)
}

/// Gets the value of the `IP_TTL` option for this socket.
///
/// For more information about this option, see [`set_ttl`].
Expand Down Expand Up @@ -390,10 +356,7 @@ impl TcpListener {
impl crate::stream::Stream for TcpListener {
type Item = io::Result<TcpStream>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (socket, _) = ready!(self.poll_accept(cx))?;
Poll::Ready(Some(Ok(socket)))
}
Expand Down
4 changes: 0 additions & 4 deletions tokio/src/net/tcp/mod.rs
@@ -1,10 +1,6 @@
//! TCP utility types

pub(crate) mod listener;
pub(crate) use listener::TcpListener;

mod incoming;
pub use incoming::Incoming;

mod split;
pub use split::{ReadHalf, WriteHalf};
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/runtime/mod.rs
Expand Up @@ -25,7 +25,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
Expand Down Expand Up @@ -73,7 +73,7 @@
//!
//! // Spawn the root task
//! rt.block_on(async {
//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/task/spawn.rs
Expand Up @@ -37,7 +37,7 @@ doc_rt_core! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/buffered.rs
Expand Up @@ -13,7 +13,7 @@ use std::thread;
async fn echo_server() {
const N: usize = 1024;

let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(srv.local_addr());

let msg = "foo bar baz";
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/io_driver.rs
Expand Up @@ -56,7 +56,7 @@ fn test_drop_on_notify() {
// Define a task that just drains the listener
let task = Arc::new(Task::new(async move {
// Create a listener
let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Send the address
let addr = listener.local_addr().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions tokio/tests/io_driver_drop.rs
Expand Up @@ -9,7 +9,7 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task};
fn tcp_doesnt_block() {
let rt = rt();

let mut listener = rt.enter(|| {
let listener = rt.enter(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
Expand All @@ -27,7 +27,7 @@ fn tcp_doesnt_block() {
fn drop_wakes() {
let rt = rt();

let mut listener = rt.enter(|| {
let listener = rt.enter(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
Expand Down
12 changes: 6 additions & 6 deletions tokio/tests/rt_common.rs
Expand Up @@ -471,7 +471,7 @@ rt_test! {
rt.block_on(async move {
let (tx, rx) = oneshot::channel();

let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
Expand Down Expand Up @@ -539,7 +539,7 @@ rt_test! {
let rt = rt();

rt.block_on(async move {
let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());

let peer = tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -634,7 +634,7 @@ rt_test! {

// Do some I/O work
rt.block_on(async {
let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());

let srv = tokio::spawn(async move {
Expand Down Expand Up @@ -912,7 +912,7 @@ rt_test! {
}

async fn client_server(tx: mpsc::Sender<()>) {
let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Get the assigned address
let addr = assert_ok!(server.local_addr());
Expand Down Expand Up @@ -943,7 +943,7 @@ rt_test! {
local.block_on(&rt, async move {
let (tx, rx) = oneshot::channel();

let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

task::spawn_local(async move {
Expand All @@ -970,7 +970,7 @@ rt_test! {
}

async fn client_server_local(tx: mpsc::Sender<()>) {
let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Get the assigned address
let addr = assert_ok!(server.local_addr());
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/rt_threaded.rs
Expand Up @@ -139,7 +139,7 @@ fn spawn_shutdown() {
}

async fn client_server(tx: mpsc::Sender<()>) {
let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Get the assigned address
let addr = assert_ok!(server.local_addr());
Expand Down
28 changes: 11 additions & 17 deletions tokio/tests/tcp_accept.rs
Expand Up @@ -5,14 +5,15 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio_test::assert_ok;

use std::io;
use std::net::{IpAddr, SocketAddr};

macro_rules! test_accept {
($(($ident:ident, $target:expr),)*) => {
$(
#[tokio::test]
async fn $ident() {
let mut listener = assert_ok!(TcpListener::bind($target).await);
let listener = assert_ok!(TcpListener::bind($target).await);
let addr = listener.local_addr().unwrap();

let (tx, rx) = oneshot::channel();
Expand All @@ -39,7 +40,6 @@ test_accept! {
(ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
}

use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Expand All @@ -48,23 +48,17 @@ use std::sync::{
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};

pin_project! {
struct TrackPolls<S> {
npolls: Arc<AtomicUsize>,
#[pin]
s: S,
}
struct TrackPolls<'a> {
npolls: Arc<AtomicUsize>,
listener: &'a mut TcpListener,
}

impl<S> Stream for TrackPolls<S>
where
S: Stream,
{
type Item = S::Item;
impl<'a> Stream for TrackPolls<'a> {
type Item = io::Result<(TcpStream, SocketAddr)>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
this.npolls.fetch_add(1, SeqCst);
this.s.poll_next(cx)
self.npolls.fetch_add(1, SeqCst);
self.listener.poll_accept(cx).map(Some)
}
}

Expand All @@ -79,7 +73,7 @@ async fn no_extra_poll() {
tokio::spawn(async move {
let mut incoming = TrackPolls {
npolls: Arc::new(AtomicUsize::new(0)),
s: listener.incoming(),
listener: &mut listener,
};
assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
while incoming.next().await.is_some() {
Expand Down

0 comments on commit 4889d97

Please sign in to comment.