From 5f35b00e772858aa2d77f8b298f3d1567c9e2e40 Mon Sep 17 00:00:00 2001 From: Evan Cameron Date: Fri, 16 Oct 2020 23:55:26 -0400 Subject: [PATCH] tokio: Add back poll_* for udp --- tokio/src/net/udp/socket.rs | 70 +++++++++++++++++++++++++++++++++++++ tokio/tests/udp.rs | 68 +++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+) diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index 77e5dd43e7b..cd5932e1e27 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -5,6 +5,7 @@ use std::convert::TryFrom; use std::fmt; use std::io; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::task::{Context, Poll}; cfg_net! { /// A UDP socket @@ -271,6 +272,21 @@ impl UdpSocket { .await } + #[doc(hidden)] + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send(buf) { + Ok(len) => return Poll::Ready(Ok(len)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + /// Try to send data on the socket to the remote address to which it is /// connected. /// @@ -303,6 +319,21 @@ impl UdpSocket { .await } + #[doc(hidden)] + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().recv(buf) { + Ok(len) => return Poll::Ready(Ok(len)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + /// Returns a future that sends data on the socket to the given address. /// On success, the future will resolve to the number of bytes written. /// @@ -336,6 +367,26 @@ impl UdpSocket { } } + #[doc(hidden)] + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: &SocketAddr, + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send_to(buf, *target) { + Ok(len) => return Poll::Ready(Ok(len)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + /// Try to send data on the socket to the given address, but if the send is blocked /// this will return right away. /// @@ -402,6 +453,25 @@ impl UdpSocket { .await } + #[doc(hidden)] + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().recv_from(buf) { + Ok(ret) => return Poll::Ready(Ok(ret)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + /// Gets the value of the `SO_BROADCAST` option for this socket. /// /// For more information about this option, see [`set_broadcast`]. diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 0bea83aa596..473302dfc54 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -1,6 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use futures::future::poll_fn; use std::sync::Arc; use tokio::net::UdpSocket; @@ -24,6 +25,23 @@ async fn send_recv() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn send_recv_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + sender.connect(receiver.local_addr()?).await?; + receiver.connect(sender.local_addr()?).await?; + + poll_fn(|cx| sender.poll_send(cx, MSG)).await?; + + let mut recv_buf = [0u8; 32]; + let len = poll_fn(|cx| receiver.poll_recv(cx, &mut recv_buf[..])).await?; + + assert_eq!(&recv_buf[..len], MSG); + Ok(()) +} + #[tokio::test] async fn send_to_recv_from() -> std::io::Result<()> { let sender = UdpSocket::bind("127.0.0.1:0").await?; @@ -40,6 +58,22 @@ async fn send_to_recv_from() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn send_to_recv_from_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + let mut recv_buf = [0u8; 32]; + let (len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut recv_buf[..])).await?; + + assert_eq!(&recv_buf[..len], MSG); + assert_eq!(addr, sender.local_addr()?); + Ok(()) +} + #[tokio::test] async fn split() -> std::io::Result<()> { let socket = UdpSocket::bind("127.0.0.1:0").await?; @@ -88,6 +122,40 @@ async fn split_chan() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn split_chan_poll() -> std::io::Result<()> { + // setup UdpSocket that will echo all sent items + let socket = UdpSocket::bind("127.0.0.1:0").await?; + let addr = socket.local_addr().unwrap(); + let s = Arc::new(socket); + let r = s.clone(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec, std::net::SocketAddr)>(1_000); + tokio::spawn(async move { + while let Some((bytes, addr)) = rx.recv().await { + poll_fn(|cx| s.poll_send_to(cx, &bytes, &addr)) + .await + .unwrap(); + } + }); + + tokio::spawn(async move { + let mut buf = [0u8; 32]; + loop { + let (len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut buf)).await.unwrap(); + tx.send((buf[..len].to_vec(), addr)).await.unwrap(); + } + }); + + // test that we can send a value and get back some response + let sender = UdpSocket::bind("127.0.0.1:0").await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?; + let mut recv_buf = [0u8; 32]; + let (len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut recv_buf)).await?; + assert_eq!(&recv_buf[..len], MSG); + Ok(()) +} + // # Note // // This test is purposely written such that each time `sender` sends data on