Skip to content

Commit

Permalink
udp: add peek_from & poll_peek_from
Browse files Browse the repository at this point in the history
  • Loading branch information
leshow committed Oct 21, 2020
1 parent 599986e commit 3871034
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 13 deletions.
110 changes: 100 additions & 10 deletions tokio/src/net/udp/socket.rs
Expand Up @@ -356,19 +356,15 @@ impl UdpSocket {
/// The function returns:
///
/// * `Poll::Pending` if the socket is not ready to read
/// * `Poll::Ready(Ok(n))` `n` is the number of bytes read.
/// * `Poll::Ready(Ok(()))` reads data `ReadBuf` if the socket is ready
/// * `Poll::Ready(Err(e))` if an error is encountered.
///
/// # Errors
///
/// This function may encounter any standard I/O error except `WouldBlock`.
///
/// [`connect`]: method@Self::connect
pub fn poll_recv(
&self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> {
pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

Expand All @@ -388,7 +384,7 @@ impl UdpSocket {
buf.assume_init(n);
}
buf.advance(n);
return Poll::Ready(Ok(n));
return Poll::Ready(Ok(()));
}
}
}
Expand Down Expand Up @@ -539,7 +535,7 @@ impl UdpSocket {
/// The function returns:
///
/// * `Poll::Pending` if the socket is not ready to read
/// * `Poll::Ready(Ok((n, addr)))` a tuple where `n` is the number of bytes received from `addr`
/// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready
/// * `Poll::Ready(Err(e))` if an error is encountered.
///
/// # Errors
Expand All @@ -549,7 +545,7 @@ impl UdpSocket {
&self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<(usize, SocketAddr)>> {
) -> Poll<io::Result<SocketAddr>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

Expand All @@ -569,7 +565,101 @@ impl UdpSocket {
buf.assume_init(n);
}
buf.advance(n);
return Poll::Ready(Ok((n, addr)));
return Poll::Ready(Ok(addr));
}
}
}
}

/// Receives data from the socket, without removing it from the input queue.
/// On success, returns the number of bytes read and the address from whence
/// the data came.
///
/// # Notes
///
/// On Windows, if the data is larger than the buffer specified, the buffer
/// is filled with the first part of the data, and peek_from returns the error
/// WSAEMSGSIZE(10040). The excess data is lost.
/// Make sure to always use a sufficiently large buffer to hold the
/// maximum UDP packet size, which can be up to 65536 bytes in size.
///
/// # Examples
///
/// ```no_run
/// # use std::error::Error;
/// #
/// # fn main() -> Result<(), Box<dyn Error>> {
/// use tokio::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:0".parse()?)?;
///
/// // We must check if the socket is readable before calling recv_from,
/// // or we could run into a WouldBlock error.
///
/// let mut buf = [0; 9];
/// let (num_recv, from_addr) = socket.peek_from(&mut buf)?;
/// println!("Received {:?} -> {:?} bytes from {:?}", buf, num_recv, from_addr);
/// #
/// # Ok(())
/// # }
/// ```
pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.io
.async_io(mio::Interest::READABLE, |sock| sock.peek_from(buf))
.await
}

/// Receives data from the socket, without removing it from the input queue.
/// On success, returns the number of bytes read.
///
/// # Notes
///
/// Note that on multiple calls to a `poll_*` method in the recv direction, only the
/// `Waker` from the `Context` passed to the most recent call will be scheduled to
/// receive a wakeup
///
/// On Windows, if the data is larger than the buffer specified, the buffer
/// is filled with the first part of the data, and peek returns the error
/// WSAEMSGSIZE(10040). The excess data is lost.
/// Make sure to always use a sufficiently large buffer to hold the
/// maximum UDP packet size, which can be up to 65536 bytes in size.
///
/// # Return value
///
/// The function returns:
///
/// * `Poll::Pending` if the socket is not ready to read
/// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready
/// * `Poll::Ready(Err(e))` if an error is encountered.
///
/// # Errors
///
/// This function may encounter any standard I/O error except `WouldBlock`.
pub fn poll_peek_from(
&self,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<SocketAddr>> {
loop {
let ev = ready!(self.io.poll_read_ready(cx))?;

// Safety: will not read the maybe uinitialized bytes.
let b = unsafe {
&mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8])
};
match self.io.get_ref().peek_from(b) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_readiness(ev);
}
Err(e) => return Poll::Ready(Err(e)),
Ok((n, addr)) => {
// Safety: We trust `recv` to have filled up `n` bytes
// in the buffer.
unsafe {
buf.assume_init(n);
}
buf.advance(n);
return Poll::Ready(Ok(addr));
}
}
}
Expand Down
62 changes: 59 additions & 3 deletions tokio/tests/udp.rs
Expand Up @@ -69,13 +69,69 @@ async fn send_to_recv_from_poll() -> std::io::Result<()> {

let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
let (_len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;
let addr = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;

assert_eq!(read.filled(), MSG);
assert_eq!(addr, sender.local_addr()?);
Ok(())
}

#[tokio::test]
async fn send_to_peek_from() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let receiver_addr = receiver.local_addr()?;
poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?;

// peek
let mut recv_buf = [0u8; 32];
let (n, addr) = receiver.peek_from(&mut recv_buf).await?;
assert_eq!(&recv_buf[..n], MSG);
assert_eq!(addr, sender.local_addr()?);

// peek
let mut recv_buf = [0u8; 32];
let (n, addr) = receiver.peek_from(&mut recv_buf).await?;
assert_eq!(&recv_buf[..n], MSG);
assert_eq!(addr, sender.local_addr()?);

let mut recv_buf = [0u8; 32];
let (n, addr) = receiver.recv_from(&mut recv_buf).await?;
assert_eq!(&recv_buf[..n], MSG);
assert_eq!(addr, sender.local_addr()?);

Ok(())
}

#[tokio::test]
async fn send_to_peek_from_poll() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let receiver_addr = receiver.local_addr()?;
poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?;

let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
let addr = poll_fn(|cx| receiver.poll_peek_from(cx, &mut read)).await?;

assert_eq!(read.filled(), MSG);
assert_eq!(addr, sender.local_addr()?);

let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
poll_fn(|cx| receiver.poll_peek_from(cx, &mut read)).await?;

assert_eq!(read.filled(), MSG);
let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);

poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;
assert_eq!(read.filled(), MSG);
Ok(())
}

#[tokio::test]
async fn split() -> std::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:0").await?;
Expand Down Expand Up @@ -145,7 +201,7 @@ async fn split_chan_poll() -> std::io::Result<()> {
let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
loop {
let (_len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap();
let addr = poll_fn(|cx| r.poll_recv_from(cx, &mut read)).await.unwrap();
tx.send((read.filled().to_vec(), addr)).await.unwrap();
}
});
Expand All @@ -156,7 +212,7 @@ async fn split_chan_poll() -> std::io::Result<()> {

let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
let (_len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?;
let _ = poll_fn(|cx| sender.poll_recv_from(cx, &mut read)).await?;
assert_eq!(read.filled(), MSG);
Ok(())
}
Expand Down

0 comments on commit 3871034

Please sign in to comment.