diff --git a/tokio-util/src/udp/frame.rs b/tokio-util/src/udp/frame.rs index 5b098bd49b2..560f35c9cfa 100644 --- a/tokio-util/src/udp/frame.rs +++ b/tokio-util/src/udp/frame.rs @@ -6,6 +6,7 @@ use bytes::{BufMut, BytesMut}; use futures_core::ready; use futures_sink::Sink; use std::io; +use std::mem::MaybeUninit; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -36,6 +37,8 @@ pub struct UdpFramed { wr: BytesMut, out_addr: SocketAddr, flushed: bool, + is_readable: bool, + current_addr: Option, } impl Stream for UdpFramed { @@ -46,27 +49,39 @@ impl Stream for UdpFramed { pin.rd.reserve(INITIAL_RD_CAPACITY); - let (_n, addr) = unsafe { - // Read into the buffer without having to initialize the memory. - // - // safety: we know tokio::net::UdpSocket never reads from the memory - // during a recv - let res = { - let bytes = &mut *(pin.rd.bytes_mut() as *mut _ as *mut [u8]); - ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, bytes)) - }; + loop { + // Are there are still bytes left in the read buffer to decode? + if pin.is_readable { + if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { + let current_addr = pin + .current_addr + .expect("will always be set before this line is called"); - let (n, addr) = res?; - pin.rd.advance_mut(n); - (n, addr) - }; + return Poll::Ready(Some(Ok((frame, current_addr)))); + } + + // if this line has been reached then decode has returned `None`. + pin.is_readable = false; + pin.rd.clear(); + } - let frame_res = pin.codec.decode(&mut pin.rd); - pin.rd.clear(); - let frame = frame_res?; - let result = frame.map(|frame| Ok((frame, addr))); // frame -> (frame, addr) + // We're out of data. Try and fetch more data to decode + let addr = unsafe { + // Convert `&mut [MaybeUnit]` to `&mut [u8]` because we will be + // writing to it via `poll_recv_from` and therefore initializing the memory. + let buf: &mut [u8] = + &mut *(pin.rd.bytes_mut() as *mut [MaybeUninit] as *mut [u8]); - Poll::Ready(result) + let res = ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, buf)); + + let (n, addr) = res?; + pin.rd.advance_mut(n); + addr + }; + + pin.current_addr = Some(addr); + pin.is_readable = true; + } } } @@ -148,6 +163,8 @@ impl UdpFramed { rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), flushed: true, + is_readable: false, + current_addr: None, } } diff --git a/tokio-util/tests/udp.rs b/tokio-util/tests/udp.rs index 0ba0574281c..d0320beb185 100644 --- a/tokio-util/tests/udp.rs +++ b/tokio-util/tests/udp.rs @@ -1,5 +1,5 @@ use tokio::{net::UdpSocket, stream::StreamExt}; -use tokio_util::codec::{Decoder, Encoder}; +use tokio_util::codec::{Decoder, Encoder, LinesCodec}; use tokio_util::udp::UdpFramed; use bytes::{BufMut, BytesMut}; @@ -10,7 +10,7 @@ use std::io; #[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] #[tokio::test] -async fn send_framed() -> std::io::Result<()> { +async fn send_framed_byte_codec() -> std::io::Result<()> { let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; @@ -77,3 +77,24 @@ impl Encoder<&[u8]> for ByteCodec { Ok(()) } } + +#[tokio::test] +async fn send_framed_lines_codec() -> std::io::Result<()> { + let a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + Ok(()) +}