From c28cfbe5be41d301c9f793ab51c9e579befd5234 Mon Sep 17 00:00:00 2001 From: Evan Cameron Date: Wed, 21 Oct 2020 16:55:18 -0400 Subject: [PATCH] udp: add peek_from & poll_peek_from --- tokio/src/net/udp/socket.rs | 110 ++++++++++++++++++++++++++++++++---- tokio/tests/udp.rs | 62 +++++++++++++++++++- 2 files changed, 159 insertions(+), 13 deletions(-) diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index dfce637c636..1a88f5a6c7c 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -355,7 +355,7 @@ impl UdpSocket { /// The function returns: /// /// * `Poll::Pending` if the socket is not ready to read - /// * `Poll::Ready(Ok(n))` `n` is the number of bytes read. + /// * `Poll::Ready(Ok(()))` reads data `ReadBuf` if the socket is ready /// * `Poll::Ready(Err(e))` if an error is encountered. /// /// # Errors @@ -363,11 +363,7 @@ impl UdpSocket { /// This function may encounter any standard I/O error except `WouldBlock`. /// /// [`connect`]: method@Self::connect - pub fn poll_recv( - &self, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { loop { let ev = ready!(self.io.poll_read_ready(cx))?; @@ -387,7 +383,7 @@ impl UdpSocket { buf.assume_init(n); } buf.advance(n); - return Poll::Ready(Ok(n)); + return Poll::Ready(Ok(())); } } } @@ -538,7 +534,7 @@ impl UdpSocket { /// The function returns: /// /// * `Poll::Pending` if the socket is not ready to read - /// * `Poll::Ready(Ok((n, addr)))` a tuple where `n` is the number of bytes received from `addr` + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready /// * `Poll::Ready(Err(e))` if an error is encountered. /// /// # Errors @@ -548,7 +544,7 @@ impl UdpSocket { &self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, - ) -> Poll> { + ) -> Poll> { loop { let ev = ready!(self.io.poll_read_ready(cx))?; @@ -568,7 +564,101 @@ impl UdpSocket { buf.assume_init(n); } buf.advance(n); - return Poll::Ready(Ok((n, addr))); + return Poll::Ready(Ok(addr)); + } + } + } + } + + /// Receives data from the socket, without removing it from the input queue. + /// On success, returns the number of bytes read and the address from whence + /// the data came. + /// + /// # Notes + /// + /// On Windows, if the data is larger than the buffer specified, the buffer + /// is filled with the first part of the data, and peek_from returns the error + /// WSAEMSGSIZE(10040). The excess data is lost. + /// Make sure to always use a sufficiently large buffer to hold the + /// maximum UDP packet size, which can be up to 65536 bytes in size. + /// + /// # Examples + /// + /// ```no_run + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use tokio::net::UdpSocket; + /// + /// let socket = UdpSocket::bind("127.0.0.1:0".parse()?)?; + /// + /// // We must check if the socket is readable before calling recv_from, + /// // or we could run into a WouldBlock error. + /// + /// let mut buf = [0; 9]; + /// let (num_recv, from_addr) = socket.peek_from(&mut buf)?; + /// println!("Received {:?} -> {:?} bytes from {:?}", buf, num_recv, from_addr); + /// # + /// # Ok(()) + /// # } + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io + .async_io(mio::Interest::READABLE, |sock| sock.peek_from(buf)) + .await + } + + /// Receives data from the socket, without removing it from the input queue. + /// On success, returns the number of bytes read. + /// + /// # Notes + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup + /// + /// On Windows, if the data is larger than the buffer specified, the buffer + /// is filled with the first part of the data, and peek returns the error + /// WSAEMSGSIZE(10040). The excess data is lost. + /// Make sure to always use a sufficiently large buffer to hold the + /// maximum UDP packet size, which can be up to 65536 bytes in size. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_peek_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]) + }; + match self.io.get_ref().peek_from(b) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), + Ok((n, addr)) => { + // Safety: We trust `recv` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(addr)); } } } diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index eac1d869533..8b79cb850be 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -69,13 +69,69 @@ async fn send_to_recv_from_poll() -> std::io::Result<()> { let mut recv_buf = [0u8; 32]; let mut read = ReadBuf::new(&mut recv_buf); - let (_len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + let addr = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; assert_eq!(read.filled(), MSG); assert_eq!(addr, sender.local_addr()?); Ok(()) } +#[tokio::test] +async fn send_to_peek_from() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + // peek + let mut recv_buf = [0u8; 32]; + let (n, addr) = receiver.peek_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], MSG); + assert_eq!(addr, sender.local_addr()?); + + // peek + let mut recv_buf = [0u8; 32]; + let (n, addr) = receiver.peek_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], MSG); + assert_eq!(addr, sender.local_addr()?); + + let mut recv_buf = [0u8; 32]; + let (n, addr) = receiver.recv_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], MSG); + assert_eq!(addr, sender.local_addr()?); + + Ok(()) +} + +#[tokio::test] +async fn send_to_peek_from_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let addr = poll_fn(|cx| receiver.poll_peek_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), MSG); + assert_eq!(addr, sender.local_addr()?); + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + poll_fn(|cx| receiver.poll_peek_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), MSG); + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + + poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + assert_eq!(read.filled(), MSG); + Ok(()) +} + #[tokio::test] async fn split() -> std::io::Result<()> { let socket = UdpSocket::bind("127.0.0.1:0").await?; @@ -145,7 +201,7 @@ async fn split_chan_poll() -> std::io::Result<()> { let mut recv_buf = [0u8; 32]; let mut read = ReadBuf::new(&mut recv_buf); loop { - let (_len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap(); + let addr = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap(); tx.send((read.filled().to_vec(), addr)).await.unwrap(); } }); @@ -156,7 +212,7 @@ async fn split_chan_poll() -> std::io::Result<()> { let mut recv_buf = [0u8; 32]; let mut read = ReadBuf::new(&mut recv_buf); - let (_len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?; + let _ = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?; assert_eq!(read.filled(), MSG); Ok(()) }