From a78f371759d8b9eae7fc799e7ed74456ccf7d7a0 Mon Sep 17 00:00:00 2001 From: cssivision Date: Sat, 26 Dec 2020 21:14:29 +0800 Subject: [PATCH] net: Add try_read_buf and try_recv_buf --- tokio/src/net/tcp/stream.rs | 85 ++++++++++++++++ tokio/src/net/udp.rs | 140 +++++++++++++++++++++++++- tokio/src/net/unix/datagram/socket.rs | 132 ++++++++++++++++++++++++ tokio/src/net/unix/stream.rs | 87 ++++++++++++++++ tokio/tests/tcp_stream.rs | 78 ++++++++++++++ tokio/tests/udp.rs | 87 ++++++++++++++++ tokio/tests/uds_datagram.rs | 92 +++++++++++++++++ tokio/tests/uds_stream.rs | 82 +++++++++++++++ 8 files changed, 782 insertions(+), 1 deletion(-) diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index d4bfba4606a..df0a08d583d 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -17,6 +17,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; +use bytes::BufMut; + cfg_net! { /// A TCP stream between a local and a remote socket. /// @@ -559,6 +561,89 @@ impl TcpStream { .try_io(Interest::READABLE, || (&*self.io).read(buf)) } + /// Try to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: TcpStream::readable() + /// [`ready()`]: TcpStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = Vec::with_capacity(4096); + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_buf(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_buf(&self, buf: &mut B) -> io::Result { + if !buf.has_remaining_mut() { + return Ok(0); + } + + self.io.registration().try_io(Interest::READABLE, || { + use std::io::Read; + + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `TcpStream::read` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).read(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + /// Wait for the socket to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually diff --git a/tokio/src/net/udp.rs b/tokio/src/net/udp.rs index 23abe98e457..2322281738f 100644 --- a/tokio/src/net/udp.rs +++ b/tokio/src/net/udp.rs @@ -7,6 +7,8 @@ use std::io; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::task::{Context, Poll}; +use bytes::BufMut; + cfg_net! { /// A UDP socket /// @@ -683,6 +685,77 @@ impl UdpSocket { .try_io(Interest::READABLE, || self.io.recv(buf)) } + /// Try to receive data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// The function must be called with valid byte array buf of sufficient size + /// to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// socket.connect("127.0.0.1:8081").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_from(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf(&self, buf: &mut B) -> io::Result { + if !buf.has_remaining_mut() { + return Ok(0); + } + + self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).recv(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + /// Sends data on the socket to the given address. On success, returns the /// number of bytes written. /// @@ -904,7 +977,6 @@ impl UdpSocket { /// async fn main() -> io::Result<()> { /// // Connect to a peer /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; - /// socket.connect("127.0.0.1:8081").await?; /// /// loop { /// // Wait for the socket to be readable @@ -939,6 +1011,72 @@ impl UdpSocket { .try_io(Interest::READABLE, || self.io.recv_from(buf)) } + /// Try to receive a single datagram message on the socket. On success, + /// returns the number of bytes read and the origin. + /// + /// The function must be called with valid byte array buf of sufficient size + /// to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let socket = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf_from(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> { + self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the + // buffer. + let (n, addr) = (&*self.io).recv_from(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok((n, addr)) + }) + } + /// Receives data from the socket, without removing it from the input queue. /// On success, returns the number of bytes read and the address from whence /// the data came. diff --git a/tokio/src/net/unix/datagram/socket.rs b/tokio/src/net/unix/datagram/socket.rs index fb5f602959d..a5afabb251d 100644 --- a/tokio/src/net/unix/datagram/socket.rs +++ b/tokio/src/net/unix/datagram/socket.rs @@ -10,6 +10,8 @@ use std::os::unix::net; use std::path::Path; use std::task::{Context, Poll}; +use bytes::BufMut; + cfg_net_unix! { /// An I/O object representing a Unix datagram socket. /// @@ -652,6 +654,73 @@ impl UnixDatagram { .try_io(Interest::READABLE, || self.io.recv(buf)) } + /// Try to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_buf(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf(&self, buf: &mut B) -> io::Result { + if !buf.has_remaining_mut() { + return Ok(0); + } + + self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `UnixDatagram::recv` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).recv(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + /// Sends data on the socket to the specified address. /// /// # Examples @@ -930,6 +999,69 @@ impl UnixDatagram { Ok((n, SocketAddr(addr))) } + /// Try to receive data from the socket without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = Vec::with_capacity(1024); + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_buf_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_recv_buf_from(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> { + let (n, addr) = self.io.registration().try_io(Interest::READABLE, || { + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the + // buffer. + let (n, addr) = (&*self.io).recv_from(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok((n, addr)) + })?; + + Ok((n, SocketAddr(addr))) + } + /// Returns the local address that this socket is bound to. /// /// # Examples diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index dc929dc21ea..ecc5602edb1 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -15,6 +15,8 @@ use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; +use bytes::BufMut; + cfg_net_unix! { /// A structure representing a connected Unix socket. /// @@ -267,6 +269,91 @@ impl UnixStream { .try_io(Interest::READABLE, || (&*self.io).read(buf)) } + /// Try to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`ready()`]: UnixStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = Vec::with_capacity(4096); + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_buf(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_buf(&self, buf: &mut B) -> io::Result { + if !buf.has_remaining_mut() { + return Ok(0); + } + + self.io.registration().try_io(Interest::READABLE, || { + use std::io::Read; + + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `UnixStream::read` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).read(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + /// Wait for the socket to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs index 58b06ee3233..e34c2bbc6f7 100644 --- a/tokio/tests/tcp_stream.rs +++ b/tokio/tests/tcp_stream.rs @@ -234,3 +234,81 @@ fn write_until_pending(stream: &mut TcpStream) { } } } + +#[tokio::test] +async fn try_read_buf() { + const DATA: &[u8] = b"this is some data to write to the socket"; + + // Create listener + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + + // Create socket pair + let client = TcpStream::connect(listener.local_addr().unwrap()) + .await + .unwrap(); + let (server, _) = listener.accept().await.unwrap(); + let mut written = DATA.to_vec(); + + // Track the server receiving data + let mut readable = task::spawn(server.readable()); + assert_pending!(readable.poll()); + + // Write data. + client.writable().await.unwrap(); + assert_eq!(DATA.len(), client.try_write(DATA).unwrap()); + + // The task should be notified + while !readable.is_woken() { + tokio::task::yield_now().await; + } + + // Fill the write buffer + loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + + match client.try_write(DATA) { + Ok(n) => written.extend(&DATA[..n]), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(e) => panic!("error = {:?}", e), + } + } + + { + // Write buffer full + let mut writable = task::spawn(client.writable()); + assert_pending!(writable.poll()); + + // Drain the socket from the server end + let mut read = Vec::with_capacity(written.len()); + let mut i = 0; + + while i < read.capacity() { + server.readable().await.unwrap(); + + match server.try_read_buf(&mut read) { + Ok(n) => i += n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("error = {:?}", e), + } + } + + assert_eq!(read, written); + } + + // Now, we listen for shutdown + drop(client); + + loop { + let ready = server.ready(Interest::READABLE).await.unwrap(); + + if ready.is_read_closed() { + return; + } else { + tokio::task::yield_now().await; + } + } +} diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 7cbba1b3585..715d8ebe2c7 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -353,3 +353,90 @@ async fn try_send_to_recv_from() { } } } + +#[tokio::test] +async fn try_recv_buf() { + // Create listener + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + // Create socket pair + let client = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + // Connect the two + client.connect(server.local_addr().unwrap()).await.unwrap(); + server.connect(client.local_addr().unwrap()).await.unwrap(); + + for _ in 0..5 { + loop { + client.writable().await.unwrap(); + + match client.try_send(b"hello world") { + Ok(n) => { + assert_eq!(n, 11); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + + loop { + server.readable().await.unwrap(); + + let mut buf = Vec::with_capacity(512); + + match server.try_recv_buf(&mut buf) { + Ok(n) => { + assert_eq!(n, 11); + assert_eq!(&buf[0..11], &b"hello world"[..]); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + } +} + +#[tokio::test] +async fn try_recv_buf_from() { + // Create listener + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let saddr = server.local_addr().unwrap(); + + // Create socket pair + let client = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let caddr = client.local_addr().unwrap(); + + for _ in 0..5 { + loop { + client.writable().await.unwrap(); + + match client.try_send_to(b"hello world", saddr) { + Ok(n) => { + assert_eq!(n, 11); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + + loop { + server.readable().await.unwrap(); + + let mut buf = Vec::with_capacity(512); + + match server.try_recv_buf_from(&mut buf) { + Ok((n, addr)) => { + assert_eq!(n, 11); + assert_eq!(addr, caddr); + assert_eq!(&buf[0..11], &b"hello world"[..]); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + } +} diff --git a/tokio/tests/uds_datagram.rs b/tokio/tests/uds_datagram.rs index cdabd7b196e..10314bebf9c 100644 --- a/tokio/tests/uds_datagram.rs +++ b/tokio/tests/uds_datagram.rs @@ -230,3 +230,95 @@ async fn try_send_to_recv_from() -> std::io::Result<()> { Ok(()) } + +#[tokio::test] +async fn try_recv_buf_from() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let server_path = dir.path().join("server.sock"); + let client_path = dir.path().join("client.sock"); + + // Create listener + let server = UnixDatagram::bind(&server_path)?; + + // Create socket pair + let client = UnixDatagram::bind(&client_path)?; + + for _ in 0..5 { + loop { + client.writable().await?; + + match client.try_send_to(b"hello world", &server_path) { + Ok(n) => { + assert_eq!(n, 11); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + + loop { + server.readable().await?; + + let mut buf = Vec::with_capacity(512); + + match server.try_recv_buf_from(&mut buf) { + Ok((n, addr)) => { + assert_eq!(n, 11); + assert_eq!(addr.as_pathname(), Some(client_path.as_ref())); + assert_eq!(&buf[0..11], &b"hello world"[..]); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + } + + Ok(()) +} + +// Even though we use sync non-blocking io we still need a reactor. +#[tokio::test] +async fn try_recv_buf_never_block() -> io::Result<()> { + let payload = b"PAYLOAD"; + let mut count = 0; + + let (dgram1, dgram2) = UnixDatagram::pair()?; + + // Send until we hit the OS `net.unix.max_dgram_qlen`. + loop { + dgram1.writable().await.unwrap(); + + match dgram1.try_send(payload) { + Err(err) => match err.kind() { + io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, + _ => unreachable!("unexpected error {:?}", err), + }, + Ok(len) => { + assert_eq!(len, payload.len()); + } + } + count += 1; + } + + // Read every dgram we sent. + while count > 0 { + let mut recv_buf = Vec::with_capacity(16); + + dgram2.readable().await.unwrap(); + let len = dgram2.try_recv_buf(&mut recv_buf)?; + assert_eq!(len, payload.len()); + assert_eq!(payload, &recv_buf[..len]); + count -= 1; + } + + let mut recv_buf = vec![0; 16]; + let err = dgram2.try_recv_from(&mut recv_buf).unwrap_err(); + match err.kind() { + io::ErrorKind::WouldBlock => (), + _ => unreachable!("unexpected error {:?}", err), + } + + Ok(()) +} diff --git a/tokio/tests/uds_stream.rs b/tokio/tests/uds_stream.rs index 5160f17e4da..c528620ba9a 100644 --- a/tokio/tests/uds_stream.rs +++ b/tokio/tests/uds_stream.rs @@ -252,3 +252,85 @@ fn write_until_pending(stream: &mut UnixStream) { } } } + +#[tokio::test] +async fn try_read_buf() -> std::io::Result<()> { + let msg = b"hello world"; + + let dir = tempfile::tempdir()?; + let bind_path = dir.path().join("bind.sock"); + + // Create listener + let listener = UnixListener::bind(&bind_path)?; + + // Create socket pair + let client = UnixStream::connect(&bind_path).await?; + + let (server, _) = listener.accept().await?; + let mut written = msg.to_vec(); + + // Track the server receiving data + let mut readable = task::spawn(server.readable()); + assert_pending!(readable.poll()); + + // Write data. + client.writable().await?; + assert_eq!(msg.len(), client.try_write(msg)?); + + // The task should be notified + while !readable.is_woken() { + tokio::task::yield_now().await; + } + + // Fill the write buffer + loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + + match client.try_write(msg) { + Ok(n) => written.extend(&msg[..n]), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(e) => panic!("error = {:?}", e), + } + } + + { + // Write buffer full + let mut writable = task::spawn(client.writable()); + assert_pending!(writable.poll()); + + // Drain the socket from the server end + let mut read = Vec::with_capacity(written.len()); + let mut i = 0; + + while i < read.capacity() { + server.readable().await?; + + match server.try_read_buf(&mut read) { + Ok(n) => i += n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("error = {:?}", e), + } + } + + assert_eq!(read, written); + } + + // Now, we listen for shutdown + drop(client); + + loop { + let ready = server.ready(Interest::READABLE).await?; + + if ready.is_read_closed() { + break; + } else { + tokio::task::yield_now().await; + } + } + + Ok(()) +}