diff --git a/tokio/src/net/udp/socket.rs b/tokio/src/net/udp/socket.rs index 77e5dd43e7b..8c1bcf42d23 100644 --- a/tokio/src/net/udp/socket.rs +++ b/tokio/src/net/udp/socket.rs @@ -5,6 +5,7 @@ use std::convert::TryFrom; use std::fmt; use std::io; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::task::{Context, Poll}; cfg_net! { /// A UDP socket @@ -271,6 +272,38 @@ impl UdpSocket { .await } + /// Attempts to send data on the socket to the remote address to which it was previously + /// `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will resolve to an error if the socket is not connected.. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not available to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), + } + } + } + /// Try to send data on the socket to the remote address to which it is /// connected. /// @@ -303,6 +336,38 @@ impl UdpSocket { .await } + /// Attempts to receive a single datagram message on the socket from the remote + /// address to which it is `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will resolve to an error if the socket is not connected.. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes read. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + match self.io.get_ref().recv(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), + } + } + } + /// Returns a future that sends data on the socket to the given address. /// On success, the future will resolve to the number of bytes written. /// @@ -336,6 +401,37 @@ impl UdpSocket { } } + /// Attempts to send data on the socket to a given address. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: &SocketAddr, + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().send_to(buf, *target) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), + } + } + } + /// Try to send data on the socket to the given address, but if the send is blocked /// this will return right away. /// @@ -402,6 +498,36 @@ impl UdpSocket { .await } + /// Attempts to receive a single datagram on the socket. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok((n, addr)))` a tuple where `n` is the number of bytes received from `addr` + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + match self.io.get_ref().recv_from(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), + } + } + } + /// Gets the value of the `SO_BROADCAST` option for this socket. /// /// For more information about this option, see [`set_broadcast`]. diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 0bea83aa596..473302dfc54 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -1,6 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use futures::future::poll_fn; use std::sync::Arc; use tokio::net::UdpSocket; @@ -24,6 +25,23 @@ async fn send_recv() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn send_recv_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + sender.connect(receiver.local_addr()?).await?; + receiver.connect(sender.local_addr()?).await?; + + poll_fn(|cx| sender.poll_send(cx, MSG)).await?; + + let mut recv_buf = [0u8; 32]; + let len = poll_fn(|cx| receiver.poll_recv(cx, &mut recv_buf[..])).await?; + + assert_eq!(&recv_buf[..len], MSG); + Ok(()) +} + #[tokio::test] async fn send_to_recv_from() -> std::io::Result<()> { let sender = UdpSocket::bind("127.0.0.1:0").await?; @@ -40,6 +58,22 @@ async fn send_to_recv_from() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn send_to_recv_from_poll() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let receiver_addr = receiver.local_addr()?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + + let mut recv_buf = [0u8; 32]; + let (len, addr) = poll_fn(|cx| receiver.poll_recv_from(cx, &mut recv_buf[..])).await?; + + assert_eq!(&recv_buf[..len], MSG); + assert_eq!(addr, sender.local_addr()?); + Ok(()) +} + #[tokio::test] async fn split() -> std::io::Result<()> { let socket = UdpSocket::bind("127.0.0.1:0").await?; @@ -88,6 +122,40 @@ async fn split_chan() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn split_chan_poll() -> std::io::Result<()> { + // setup UdpSocket that will echo all sent items + let socket = UdpSocket::bind("127.0.0.1:0").await?; + let addr = socket.local_addr().unwrap(); + let s = Arc::new(socket); + let r = s.clone(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec, std::net::SocketAddr)>(1_000); + tokio::spawn(async move { + while let Some((bytes, addr)) = rx.recv().await { + poll_fn(|cx| s.poll_send_to(cx, &bytes, &addr)) + .await + .unwrap(); + } + }); + + tokio::spawn(async move { + let mut buf = [0u8; 32]; + loop { + let (len, addr) = poll_fn(|cx| r.poll_recv_from(cx, &mut buf)).await.unwrap(); + tx.send((buf[..len].to_vec(), addr)).await.unwrap(); + } + }); + + // test that we can send a value and get back some response + let sender = UdpSocket::bind("127.0.0.1:0").await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?; + let mut recv_buf = [0u8; 32]; + let (len, _) = poll_fn(|cx| sender.poll_recv_from(cx, &mut recv_buf)).await?; + assert_eq!(&recv_buf[..len], MSG); + Ok(()) +} + // # Note // // This test is purposely written such that each time `sender` sends data on