diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index 8c1bcf42d23..3229432507a 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -1,4 +1,4 @@ -use crate::io::PollEvented; +use crate::io::{PollEvented, ReadBuf}; use crate::net::{to_socket_addrs, ToSocketAddrs}; use std::convert::TryFrom; @@ -278,6 +278,8 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. The future /// will resolve to an error if the socket is not connected.. /// + /// **Note**:`poll_*` methods are only able to associate to one task per read or write direction. + /// /// # Return value /// /// The function returns: @@ -342,6 +344,8 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. The future /// will resolve to an error if the socket is not connected.. /// + /// **Note**: `poll_*` methods are only able to associate to one task per read or write direction. + /// /// # Return value /// /// The function returns: @@ -355,15 +359,32 @@ impl UdpSocket { /// This function may encounter any standard I/O error except `WouldBlock`. /// /// [`connect`]: method@Self::connect - pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + pub fn poll_recv( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { loop { let ev = ready!(self.io.poll_read_ready(cx))?; - match self.io.get_ref().recv(buf) { + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + match self.io.get_ref().recv(b) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.clear_readiness(ev); } - x => return Poll::Ready(x), + Err(e) => return Poll::Ready(Err(e)), + Ok(n) => { + // Safety: We trust `recv` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(n)); + } } } } @@ -403,6 +424,8 @@ impl UdpSocket { /// Attempts to send data on the socket to a given address. /// + /// **Note**: `poll_*` methods are only able to associate to one task per read or write direction. + /// /// # Return value /// /// The function returns: @@ -500,6 +523,8 @@ impl UdpSocket { /// Attempts to receive a single datagram on the socket. /// + /// **Note**: `poll_*` methods are only able to associate to one task per read or write direction. + /// /// # Return value /// /// The function returns: @@ -514,16 +539,29 @@ impl UdpSocket { pub fn poll_recv_from( &self, cx: &mut Context<'_>, - buf: &mut [u8], + buf: &mut ReadBuf<'_>, ) -> Poll> { loop { let ev = ready!(self.io.poll_read_ready(cx))?; - match self.io.get_ref().recv_from(buf) { + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + match self.io.get_ref().recv_from(b) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.clear_readiness(ev); } - x => return Poll::Ready(x), + Err(e) => return Poll::Ready(Err(e)), + Ok((n, addr)) => { + // Safety: We trust `recv` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok((n, addr))); + } } } } diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 473302dfc54..eac1d869533 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -3,7 +3,7 @@ use futures::future::poll_fn; use std::sync::Arc; -use tokio::net::UdpSocket; +use tokio::{io::ReadBuf, net::UdpSocket}; const MSG: &[u8] = b"hello"; const MSG_LEN: usize = MSG.len(); @@ -36,9 +36,10 @@ async fn send_recv_poll() -> std::io::Result<()> { poll_fn(|cx| sender.poll_send(cx, MSG)).await?; let mut recv_buf = [0u8; 32]; - let len = poll_fn(|cx| receiver.poll_recv(cx, &mut recv_buf[..])).await?; + let mut read = ReadBuf::new(&mut recv_buf); + let _len = poll_fn(|cx| receiver.poll_recv(cx, &mut read)).await?; - assert_eq!(&recv_buf[..len], MSG); + assert_eq!(read.filled(), MSG); Ok(()) } @@ -67,9 +68,10 @@ async fn send_to_recv_from_poll() -> std::io::Result<()> { poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; let mut recv_buf = [0u8; 32]; - let (len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut recv_buf[..])).await?; + let mut read = ReadBuf::new(&mut recv_buf); + let (_len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; - assert_eq!(&recv_buf[..len], MSG); + assert_eq!(read.filled(), MSG); assert_eq!(addr, sender.local_addr()?); Ok(()) } @@ -140,19 +142,22 @@ async fn split_chan_poll() -> std::io::Result<()> { }); tokio::spawn(async move { - let mut buf = [0u8; 32]; + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); loop { - let (len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut buf)).await.unwrap(); - tx.send((buf[..len].to_vec(), addr)).await.unwrap(); + let (_len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap(); + tx.send((read.filled().to_vec(), addr)).await.unwrap(); } }); // test that we can send a value and get back some response let sender = UdpSocket::bind("127.0.0.1:0").await?; poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?; + let mut recv_buf = [0u8; 32]; - let (len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut recv_buf)).await?; - assert_eq!(&recv_buf[..len], MSG); + let mut read = ReadBuf::new(&mut recv_buf); + let (_len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?; + assert_eq!(read.filled(), MSG); Ok(()) }