Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: use &self with TcpListener::accept #2919

Merged
merged 9 commits into from Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/src/lib.rs
Expand Up @@ -306,7 +306,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
Expand Down
42 changes: 0 additions & 42 deletions tokio/src/net/tcp/incoming.rs

This file was deleted.

75 changes: 19 additions & 56 deletions tokio/src/net/tcp/listener.rs
@@ -1,6 +1,5 @@
use crate::future::poll_fn;
use crate::io::PollEvented;
use crate::net::tcp::{Incoming, TcpStream};
use crate::net::tcp::TcpStream;
use crate::net::{to_socket_addrs, ToSocketAddrs};

use std::convert::TryFrom;
Expand Down Expand Up @@ -40,7 +39,7 @@ cfg_tcp! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
Expand Down Expand Up @@ -171,7 +170,7 @@ impl TcpListener {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// match listener.accept().await {
/// Ok((_socket, addr)) => println!("new client: {:?}", addr),
Expand All @@ -181,18 +180,25 @@ impl TcpListener {
/// Ok(())
/// }
/// ```
pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
poll_fn(|cx| self.poll_accept(cx)).await
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
let (mio, addr) = self
.io
.async_io(mio::Interest::READABLE, |sock| sock.accept())
.await?;

let stream = TcpStream::new(mio)?;
Ok((stream, addr))
}

/// Polls to accept a new incoming connection to this listener.
///
/// If there is no connection to accept, `Poll::Pending` is returned and
/// the current task will be notified by a waker.
pub fn poll_accept(
&mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
/// If there is no connection to accept, `Poll::Pending` is returned and the
/// current task will be notified by a waker.
///
/// When ready, the most recent task that called `poll_accept` is notified.
/// The caller is responsble to ensure that `poll_accept` is called from a
/// single task. Failing to do this could result in tasks hanging.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should eventually include language that poll_accept should only be used for 1 task.

pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

Expand Down Expand Up @@ -293,46 +299,6 @@ impl TcpListener {
self.io.get_ref().local_addr()
}

/// Returns a stream over the connections being received on this listener.
///
/// Note that `TcpListener` also directly implements `Stream`.
///
/// The returned stream will never return `None` and will also not yield the
/// peer's `SocketAddr` structure. Iterating over it is equivalent to
/// calling accept in a loop.
///
/// # Errors
///
/// Note that accepting a connection can lead to various errors and not all
/// of them are necessarily fatal ‒ for example having too many open file
/// descriptors or the other side closing the connection while it waits in
/// an accept queue. These would terminate the stream if not handled in any
/// way.
///
/// # Examples
///
/// ```no_run
/// use tokio::{net::TcpListener, stream::StreamExt};
///
/// #[tokio::main]
/// async fn main() {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
/// let mut incoming = listener.incoming();
///
/// while let Some(stream) = incoming.next().await {
/// match stream {
/// Ok(stream) => {
/// println!("new client!");
/// }
/// Err(e) => { /* connection failed */ }
/// }
/// }
/// }
/// ```
pub fn incoming(&mut self) -> Incoming<'_> {
Incoming::new(self)
}

/// Gets the value of the `IP_TTL` option for this socket.
///
/// For more information about this option, see [`set_ttl`].
Expand Down Expand Up @@ -390,10 +356,7 @@ impl TcpListener {
impl crate::stream::Stream for TcpListener {
type Item = io::Result<TcpStream>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (socket, _) = ready!(self.poll_accept(cx))?;
Poll::Ready(Some(Ok(socket)))
}
Expand Down
4 changes: 0 additions & 4 deletions tokio/src/net/tcp/mod.rs
@@ -1,10 +1,6 @@
//! TCP utility types

pub(crate) mod listener;
pub(crate) use listener::TcpListener;

mod incoming;
pub use incoming::Incoming;

mod split;
pub use split::{ReadHalf, WriteHalf};
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/runtime/mod.rs
Expand Up @@ -25,7 +25,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
Expand Down Expand Up @@ -73,7 +73,7 @@
//!
//! // Spawn the root task
//! rt.block_on(async {
//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/task/spawn.rs
Expand Up @@ -37,7 +37,7 @@ doc_rt_core! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/buffered.rs
Expand Up @@ -13,7 +13,7 @@ use std::thread;
async fn echo_server() {
const N: usize = 1024;

let mut srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(srv.local_addr());

let msg = "foo bar baz";
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/io_driver.rs
Expand Up @@ -56,7 +56,7 @@ fn test_drop_on_notify() {
// Define a task that just drains the listener
let task = Arc::new(Task::new(async move {
// Create a listener
let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Send the address
let addr = listener.local_addr().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions tokio/tests/io_driver_drop.rs
Expand Up @@ -9,7 +9,7 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task};
fn tcp_doesnt_block() {
let rt = rt();

let mut listener = rt.enter(|| {
let listener = rt.enter(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
Expand All @@ -27,7 +27,7 @@ fn tcp_doesnt_block() {
fn drop_wakes() {
let rt = rt();

let mut listener = rt.enter(|| {
let listener = rt.enter(|| {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap()
});
Expand Down
12 changes: 6 additions & 6 deletions tokio/tests/rt_common.rs
Expand Up @@ -471,7 +471,7 @@ rt_test! {
rt.block_on(async move {
let (tx, rx) = oneshot::channel();

let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
Expand Down Expand Up @@ -539,7 +539,7 @@ rt_test! {
let rt = rt();

rt.block_on(async move {
let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());

let peer = tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -634,7 +634,7 @@ rt_test! {

// Do some I/O work
rt.block_on(async {
let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());

let srv = tokio::spawn(async move {
Expand Down Expand Up @@ -912,7 +912,7 @@ rt_test! {
}

async fn client_server(tx: mpsc::Sender<()>) {
let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Get the assigned address
let addr = assert_ok!(server.local_addr());
Expand Down Expand Up @@ -943,7 +943,7 @@ rt_test! {
local.block_on(&rt, async move {
let (tx, rx) = oneshot::channel();

let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

task::spawn_local(async move {
Expand All @@ -970,7 +970,7 @@ rt_test! {
}

async fn client_server_local(tx: mpsc::Sender<()>) {
let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Get the assigned address
let addr = assert_ok!(server.local_addr());
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/rt_threaded.rs
Expand Up @@ -139,7 +139,7 @@ fn spawn_shutdown() {
}

async fn client_server(tx: mpsc::Sender<()>) {
let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

// Get the assigned address
let addr = assert_ok!(server.local_addr());
Expand Down
28 changes: 11 additions & 17 deletions tokio/tests/tcp_accept.rs
Expand Up @@ -5,14 +5,15 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio_test::assert_ok;

use std::io;
use std::net::{IpAddr, SocketAddr};

macro_rules! test_accept {
($(($ident:ident, $target:expr),)*) => {
$(
#[tokio::test]
async fn $ident() {
let mut listener = assert_ok!(TcpListener::bind($target).await);
let listener = assert_ok!(TcpListener::bind($target).await);
let addr = listener.local_addr().unwrap();

let (tx, rx) = oneshot::channel();
Expand All @@ -39,7 +40,6 @@ test_accept! {
(ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
}

use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Expand All @@ -48,23 +48,17 @@ use std::sync::{
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};

pin_project! {
struct TrackPolls<S> {
npolls: Arc<AtomicUsize>,
#[pin]
s: S,
}
struct TrackPolls<'a> {
npolls: Arc<AtomicUsize>,
listener: &'a mut TcpListener,
}

impl<S> Stream for TrackPolls<S>
where
S: Stream,
{
type Item = S::Item;
impl<'a> Stream for TrackPolls<'a> {
type Item = io::Result<(TcpStream, SocketAddr)>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
this.npolls.fetch_add(1, SeqCst);
this.s.poll_next(cx)
self.npolls.fetch_add(1, SeqCst);
self.listener.poll_accept(cx).map(Some)
}
}

Expand All @@ -79,7 +73,7 @@ async fn no_extra_poll() {
tokio::spawn(async move {
let mut incoming = TrackPolls {
npolls: Arc::new(AtomicUsize::new(0)),
s: listener.incoming(),
listener: &mut listener,
};
assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
while incoming.next().await.is_some() {
Expand Down