Skip to content

Commit

Permalink
util: make UdpFramed take Borrow<UdpSocket> (#3451)
Browse files Browse the repository at this point in the history
  • Loading branch information
leshow committed Apr 14, 2021
1 parent 39706b1 commit 9eeec03
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 19 deletions.
53 changes: 34 additions & 19 deletions tokio-util/src/udp/frame.rs
Expand Up @@ -6,9 +6,12 @@ use tokio::{io::ReadBuf, net::UdpSocket};
use bytes::{BufMut, BytesMut};
use futures_core::ready;
use futures_sink::Sink;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
borrow::Borrow,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};
use std::{io, mem::MaybeUninit};

/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
Expand All @@ -34,8 +37,8 @@ use std::{io, mem::MaybeUninit};
#[must_use = "sinks do nothing unless polled"]
#[cfg_attr(docsrs, doc(all(feature = "codec", feature = "udp")))]
#[derive(Debug)]
pub struct UdpFramed<C> {
socket: UdpSocket,
pub struct UdpFramed<C, T = UdpSocket> {
socket: T,
codec: C,
rd: BytesMut,
wr: BytesMut,
Expand All @@ -48,7 +51,13 @@ pub struct UdpFramed<C> {
const INITIAL_RD_CAPACITY: usize = 64 * 1024;
const INITIAL_WR_CAPACITY: usize = 8 * 1024;

impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
impl<C, T> Unpin for UdpFramed<C, T> {}

impl<C, T> Stream for UdpFramed<C, T>
where
T: Borrow<UdpSocket>,
C: Decoder,
{
type Item = Result<(C::Item, SocketAddr), C::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Expand Down Expand Up @@ -79,7 +88,7 @@ impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]);
let mut read = ReadBuf::uninit(buf);
let ptr = read.filled().as_ptr();
let res = ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, &mut read));
let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));

assert_eq!(ptr, read.filled().as_ptr());
let addr = res?;
Expand All @@ -93,7 +102,11 @@ impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
}
}

impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
where
T: Borrow<UdpSocket>,
C: Encoder<I>,
{
type Error = C::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Expand Down Expand Up @@ -125,13 +138,13 @@ impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
}

let Self {
ref mut socket,
ref socket,
ref mut out_addr,
ref mut wr,
..
} = *self;

let n = ready!(socket.poll_send_to(cx, &wr, *out_addr))?;
let n = ready!(socket.borrow().poll_send_to(cx, &wr, *out_addr))?;

let wrote_all = n == self.wr.len();
self.wr.clear();
Expand All @@ -156,11 +169,14 @@ impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
}
}

impl<C> UdpFramed<C> {
impl<C, T> UdpFramed<C, T>
where
T: Borrow<UdpSocket>,
{
/// Create a new `UdpFramed` backed by the given socket and codec.
///
/// See struct level documentation for more details.
pub fn new(socket: UdpSocket, codec: C) -> UdpFramed<C> {
pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
Self {
socket,
codec,
Expand All @@ -180,27 +196,21 @@ impl<C> UdpFramed<C> {
/// Care should be taken to not tamper with the underlying stream of data
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_ref(&self) -> &UdpSocket {
pub fn get_ref(&self) -> &T {
&self.socket
}

/// Returns a mutable reference to the underlying I/O stream wrapped by
/// `Framed`.
/// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
///
/// # Note
///
/// Care should be taken to not tamper with the underlying stream of data
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_mut(&mut self) -> &mut UdpSocket {
pub fn get_mut(&mut self) -> &mut T {
&mut self.socket
}

/// Consumes the `Framed`, returning its underlying I/O stream.
pub fn into_inner(self) -> UdpSocket {
self.socket
}

/// Returns a reference to the underlying codec wrapped by
/// `Framed`.
///
Expand Down Expand Up @@ -228,4 +238,9 @@ impl<C> UdpFramed<C> {
pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
&mut self.rd
}

/// Consumes the `Framed`, returning its underlying I/O stream.
pub fn into_inner(self) -> T {
self.socket
}
}
29 changes: 29 additions & 0 deletions tokio-util/tests/udp.rs
Expand Up @@ -10,6 +10,7 @@ use futures::future::try_join;
use futures::future::FutureExt;
use futures::sink::SinkExt;
use std::io;
use std::sync::Arc;

#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))]
#[tokio::test]
Expand Down Expand Up @@ -101,3 +102,31 @@ async fn send_framed_lines_codec() -> std::io::Result<()> {

Ok(())
}

#[tokio::test]
async fn framed_half() -> std::io::Result<()> {
let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
let b_soc = a_soc.clone();

let a_addr = a_soc.local_addr()?;
let b_addr = b_soc.local_addr()?;

let mut a = UdpFramed::new(a_soc, ByteCodec);
let mut b = UdpFramed::new(b_soc, LinesCodec::new());

let msg = b"1\r\n2\r\n3\r\n".to_vec();
a.send((&msg, b_addr)).await?;

let msg = b"4\r\n5\r\n6\r\n".to_vec();
a.send((&msg, b_addr)).await?;

assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));

assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr));

Ok(())
}

0 comments on commit 9eeec03

Please sign in to comment.