Skip to content

Commit

Permalink
Add ReadBuf to UdpSocket poll_recv
Browse files Browse the repository at this point in the history
  • Loading branch information
leshow committed Oct 19, 2020
1 parent b260e0c commit 8418437
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
52 changes: 45 additions & 7 deletions tokio/src/net/udp/socket.rs
@@ -1,4 +1,4 @@
use crate::io::PollEvented;
use crate::io::{PollEvented, ReadBuf};
use crate::net::{to_socket_addrs, ToSocketAddrs};

use std::convert::TryFrom;
Expand Down Expand Up @@ -278,6 +278,8 @@ impl UdpSocket {
/// The [`connect`] method will connect this socket to a remote address. The future
/// will resolve to an error if the socket is not connected..
///
/// **Note**:`poll_*` methods are only able to associate to one task per read or write direction.
///
/// # Return value
///
/// The function returns:
Expand Down Expand Up @@ -342,6 +344,8 @@ impl UdpSocket {
/// The [`connect`] method will connect this socket to a remote address. The future
/// will resolve to an error if the socket is not connected..
///
/// **Note**: `poll_*` methods are only able to associate to one task per read or write direction.
///
/// # Return value
///
/// The function returns:
Expand All @@ -355,15 +359,32 @@ impl UdpSocket {
/// This function may encounter any standard I/O error except `WouldBlock`.
///
/// [`connect`]: method@Self::connect
pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
pub fn poll_recv(
&self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

match self.io.get_ref().recv(buf) {
// Safety: will not read the maybe uinitialized bytes.
let b = unsafe {
&mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8])
};
match self.io.get_ref().recv(b) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
x => return Poll::Ready(x),
Err(e) => return Poll::Ready(Err(e)),
Ok(n) => {
// Safety: We trust `recv` to have filled up `n` bytes
// in the buffer.
unsafe {
buf.assume_init(n);
}
buf.advance(n);
return Poll::Ready(Ok(n));
}
}
}
}
Expand Down Expand Up @@ -403,6 +424,8 @@ impl UdpSocket {

/// Attempts to send data on the socket to a given address.
///
/// **Note**: `poll_*` methods are only able to associate to one task per read or write direction.
///
/// # Return value
///
/// The function returns:
Expand Down Expand Up @@ -500,6 +523,8 @@ impl UdpSocket {

/// Attempts to receive a single datagram on the socket.
///
/// **Note**: `poll_*` methods are only able to associate to one task per read or write direction.
///
/// # Return value
///
/// The function returns:
Expand All @@ -514,16 +539,29 @@ impl UdpSocket {
pub fn poll_recv_from(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<(usize, SocketAddr)>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

match self.io.get_ref().recv_from(buf) {
// Safety: will not read the maybe uinitialized bytes.
let b = unsafe {
&mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8])
};
match self.io.get_ref().recv_from(b) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
x => return Poll::Ready(x),
Err(e) => return Poll::Ready(Err(e)),
Ok((n, addr)) => {
// Safety: We trust `recv` to have filled up `n` bytes
// in the buffer.
unsafe {
buf.assume_init(n);
}
buf.advance(n);
return Poll::Ready(Ok((n, addr)));
}
}
}
}
Expand Down
25 changes: 15 additions & 10 deletions tokio/tests/udp.rs
Expand Up @@ -3,7 +3,7 @@

use futures::future::poll_fn;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::{io::ReadBuf, net::UdpSocket};

const MSG: &[u8] = b"hello";
const MSG_LEN: usize = MSG.len();
Expand Down Expand Up @@ -36,9 +36,10 @@ async fn send_recv_poll() -> std::io::Result<()> {
poll_fn(|cx| sender.poll_send(cx, MSG)).await?;

let mut recv_buf = [0u8; 32];
let len = poll_fn(|cx| receiver.poll_recv(cx, &mut recv_buf[..])).await?;
let mut read = ReadBuf::new(&mut recv_buf);
let _len = poll_fn(|cx| receiver.poll_recv(cx, &mut read)).await?;

assert_eq!(&recv_buf[..len], MSG);
assert_eq!(read.filled(), MSG);
Ok(())
}

Expand Down Expand Up @@ -67,9 +68,10 @@ async fn send_to_recv_from_poll() -> std::io::Result<()> {
poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?;

let mut recv_buf = [0u8; 32];
let (len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut recv_buf[..])).await?;
let mut read = ReadBuf::new(&mut recv_buf);
let (_len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;

assert_eq!(&recv_buf[..len], MSG);
assert_eq!(read.filled(), MSG);
assert_eq!(addr, sender.local_addr()?);
Ok(())
}
Expand Down Expand Up @@ -140,19 +142,22 @@ async fn split_chan_poll() -> std::io::Result<()> {
});

tokio::spawn(async move {
let mut buf = [0u8; 32];
let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
loop {
let (len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut buf)).await.unwrap();
tx.send((buf[..len].to_vec(), addr)).await.unwrap();
let (_len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap();
tx.send((read.filled().to_vec(), addr)).await.unwrap();
}
});

// test that we can send a value and get back some response
let sender = UdpSocket::bind("127.0.0.1:0").await?;
poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?;

let mut recv_buf = [0u8; 32];
let (len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut recv_buf)).await?;
assert_eq!(&recv_buf[..len], MSG);
let mut read = ReadBuf::new(&mut recv_buf);
let (_len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?;
assert_eq!(read.filled(), MSG);
Ok(())
}

Expand Down

0 comments on commit 8418437

Please sign in to comment.