Skip to content

Commit

Permalink
tokio: Add back poll_* for udp
Browse files Browse the repository at this point in the history
  • Loading branch information
leshow committed Oct 17, 2020
1 parent 3cc6ce7 commit 5f35b00
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
70 changes: 70 additions & 0 deletions tokio/src/net/udp/socket.rs
Expand Up @@ -5,6 +5,7 @@ use std::convert::TryFrom;
use std::fmt;
use std::io;
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::task::{Context, Poll};

cfg_net! {
/// A UDP socket
Expand Down Expand Up @@ -271,6 +272,21 @@ impl UdpSocket {
.await
}

#[doc(hidden)]
pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
loop {
let ev = ready!(self.io.poll_write_ready(cx))?;

match self.io.get_ref().send(buf) {
Ok(len) => return Poll::Ready(Ok(len)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}

/// Try to send data on the socket to the remote address to which it is
/// connected.
///
Expand Down Expand Up @@ -303,6 +319,21 @@ impl UdpSocket {
.await
}

#[doc(hidden)]
pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
loop {
let ev = ready!(self.io.poll_write_ready(cx))?;

match self.io.get_ref().recv(buf) {
Ok(len) => return Poll::Ready(Ok(len)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}

/// Returns a future that sends data on the socket to the given address.
/// On success, the future will resolve to the number of bytes written.
///
Expand Down Expand Up @@ -336,6 +367,26 @@ impl UdpSocket {
}
}

#[doc(hidden)]
pub fn poll_send_to(
&self,
cx: &mut Context<'_>,
buf: &[u8],
target: &SocketAddr,
) -> Poll<io::Result<usize>> {
loop {
let ev = ready!(self.io.poll_write_ready(cx))?;

match self.io.get_ref().send_to(buf, *target) {
Ok(len) => return Poll::Ready(Ok(len)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}

/// Try to send data on the socket to the given address, but if the send is blocked
/// this will return right away.
///
Expand Down Expand Up @@ -402,6 +453,25 @@ impl UdpSocket {
.await
}

#[doc(hidden)]
pub fn poll_recv_from(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
loop {
let ev = ready!(self.io.poll_write_ready(cx))?;

match self.io.get_ref().recv_from(buf) {
Ok(ret) => return Poll::Ready(Ok(ret)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}

/// Gets the value of the `SO_BROADCAST` option for this socket.
///
/// For more information about this option, see [`set_broadcast`].
Expand Down
68 changes: 68 additions & 0 deletions tokio/tests/udp.rs
@@ -1,6 +1,7 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use futures::future::poll_fn;
use std::sync::Arc;
use tokio::net::UdpSocket;

Expand All @@ -24,6 +25,23 @@ async fn send_recv() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn send_recv_poll() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

sender.connect(receiver.local_addr()?).await?;
receiver.connect(sender.local_addr()?).await?;

poll_fn(|cx| sender.poll_send(cx, MSG)).await?;

let mut recv_buf = [0u8; 32];
let len = poll_fn(|cx| receiver.poll_recv(cx, &mut recv_buf[..])).await?;

assert_eq!(&recv_buf[..len], MSG);
Ok(())
}

#[tokio::test]
async fn send_to_recv_from() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
Expand All @@ -40,6 +58,22 @@ async fn send_to_recv_from() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn send_to_recv_from_poll() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let receiver_addr = receiver.local_addr()?;
poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?;

let mut recv_buf = [0u8; 32];
let (len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut recv_buf[..])).await?;

assert_eq!(&recv_buf[..len], MSG);
assert_eq!(addr, sender.local_addr()?);
Ok(())
}

#[tokio::test]
async fn split() -> std::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:0").await?;
Expand Down Expand Up @@ -88,6 +122,40 @@ async fn split_chan() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn split_chan_poll() -> std::io::Result<()> {
// setup UdpSocket that will echo all sent items
let socket = UdpSocket::bind("127.0.0.1:0").await?;
let addr = socket.local_addr().unwrap();
let s = Arc::new(socket);
let r = s.clone();

let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, std::net::SocketAddr)>(1_000);
tokio::spawn(async move {
while let Some((bytes, addr)) = rx.recv().await {
poll_fn(|cx| s.poll_send_to(cx, &bytes, &addr))
.await
.unwrap();
}
});

tokio::spawn(async move {
let mut buf = [0u8; 32];
loop {
let (len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut buf)).await.unwrap();
tx.send((buf[..len].to_vec(), addr)).await.unwrap();
}
});

// test that we can send a value and get back some response
let sender = UdpSocket::bind("127.0.0.1:0").await?;
poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?;
let mut recv_buf = [0u8; 32];
let (len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut recv_buf)).await?;
assert_eq!(&recv_buf[..len], MSG);
Ok(())
}

// # Note
//
// This test is purposely written such that each time `sender` sends data on
Expand Down

0 comments on commit 5f35b00

Please sign in to comment.