diff --git a/tokio-util/src/udp/frame.rs b/tokio-util/src/udp/frame.rs index c38663543db..7e152e19abd 100644 --- a/tokio-util/src/udp/frame.rs +++ b/tokio-util/src/udp/frame.rs @@ -6,9 +6,12 @@ use tokio::{io::ReadBuf, net::UdpSocket}; use bytes::{BufMut, BytesMut}; use futures_core::ready; use futures_sink::Sink; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::pin::Pin; use std::task::{Context, Poll}; +use std::{ + borrow::Borrow, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, +}; use std::{io, mem::MaybeUninit}; /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using @@ -34,8 +37,8 @@ use std::{io, mem::MaybeUninit}; #[must_use = "sinks do nothing unless polled"] #[cfg_attr(docsrs, doc(all(feature = "codec", feature = "udp")))] #[derive(Debug)] -pub struct UdpFramed { - socket: UdpSocket, +pub struct UdpFramed { + socket: T, codec: C, rd: BytesMut, wr: BytesMut, @@ -48,7 +51,13 @@ pub struct UdpFramed { const INITIAL_RD_CAPACITY: usize = 64 * 1024; const INITIAL_WR_CAPACITY: usize = 8 * 1024; -impl Stream for UdpFramed { +impl Unpin for UdpFramed {} + +impl Stream for UdpFramed +where + T: Borrow, + C: Decoder, +{ type Item = Result<(C::Item, SocketAddr), C::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -79,7 +88,7 @@ impl Stream for UdpFramed { let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit]); let mut read = ReadBuf::uninit(buf); let ptr = read.filled().as_ptr(); - let res = ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, &mut read)); + let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); assert_eq!(ptr, read.filled().as_ptr()); let addr = res?; @@ -93,7 +102,11 @@ impl Stream for UdpFramed { } } -impl + Unpin> Sink<(I, SocketAddr)> for UdpFramed { +impl Sink<(I, SocketAddr)> for UdpFramed +where + T: Borrow, + C: Encoder, +{ type Error = C::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -125,13 +138,13 @@ impl + Unpin> Sink<(I, SocketAddr)> for UdpFramed { } let Self { - ref mut socket, + ref socket, ref mut out_addr, ref mut wr, .. } = *self; - let n = ready!(socket.poll_send_to(cx, &wr, *out_addr))?; + let n = ready!(socket.borrow().poll_send_to(cx, &wr, *out_addr))?; let wrote_all = n == self.wr.len(); self.wr.clear(); @@ -156,11 +169,14 @@ impl + Unpin> Sink<(I, SocketAddr)> for UdpFramed { } } -impl UdpFramed { +impl UdpFramed +where + T: Borrow, +{ /// Create a new `UdpFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. - pub fn new(socket: UdpSocket, codec: C) -> UdpFramed { + pub fn new(socket: T, codec: C) -> UdpFramed { Self { socket, codec, @@ -180,27 +196,21 @@ impl UdpFramed { /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. - pub fn get_ref(&self) -> &UdpSocket { + pub fn get_ref(&self) -> &T { &self.socket } - /// Returns a mutable reference to the underlying I/O stream wrapped by - /// `Framed`. + /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. /// /// # Note /// /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. - pub fn get_mut(&mut self) -> &mut UdpSocket { + pub fn get_mut(&mut self) -> &mut T { &mut self.socket } - /// Consumes the `Framed`, returning its underlying I/O stream. - pub fn into_inner(self) -> UdpSocket { - self.socket - } - /// Returns a reference to the underlying codec wrapped by /// `Framed`. /// @@ -228,4 +238,9 @@ impl UdpFramed { pub fn read_buffer_mut(&mut self) -> &mut BytesMut { &mut self.rd } + + /// Consumes the `Framed`, returning its underlying I/O stream. + pub fn into_inner(self) -> T { + self.socket + } } diff --git a/tokio-util/tests/udp.rs b/tokio-util/tests/udp.rs index 653d20deb51..b9436a30aa6 100644 --- a/tokio-util/tests/udp.rs +++ b/tokio-util/tests/udp.rs @@ -10,6 +10,7 @@ use futures::future::try_join; use futures::future::FutureExt; use futures::sink::SinkExt; use std::io; +use std::sync::Arc; #[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] #[tokio::test] @@ -101,3 +102,31 @@ async fn send_framed_lines_codec() -> std::io::Result<()> { Ok(()) } + +#[tokio::test] +async fn framed_half() -> std::io::Result<()> { + let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); + let b_soc = a_soc.clone(); + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, LinesCodec::new()); + + let msg = b"1\r\n2\r\n3\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + let msg = b"4\r\n5\r\n6\r\n".to_vec(); + a.send((&msg, b_addr)).await?; + + assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); + + assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr)); + assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr)); + + Ok(()) +}