Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement Framed::map_codec #4427

Merged
merged 5 commits into from Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 20 additions & 0 deletions tokio-util/src/codec/framed.rs
Expand Up @@ -204,6 +204,26 @@ impl<T, U> Framed<T, U> {
&mut self.inner.codec
}

/// Maps the codec `U` to `C`, preserving the read and write buffers
/// wrapped by `Framed`.
///
/// Note that care should be taken to not tamper with the underlying codec
/// as it may corrupt the stream of frames otherwise being worked with.
pub fn map_codec<C, F>(self, map: F) -> Framed<T, C>
where
F: FnOnce(U) -> C
saiintbrisson marked this conversation as resolved.
Show resolved Hide resolved
{
// This could be potentially simplified once rust-lang/rust#86555 hits stable
let parts = self.into_parts();
Framed::from_parts(FramedParts {
io: parts.io,
codec: map(parts.codec),
read_buf: parts.read_buf,
write_buf: parts.write_buf,
_priv: ()
})
}

/// Returns a mutable reference to the underlying codec wrapped by
/// `Framed`.
///
Expand Down
17 changes: 17 additions & 0 deletions tokio-util/src/codec/framed_read.rs
Expand Up @@ -108,6 +108,23 @@ impl<T, D> FramedRead<T, D> {
&mut self.inner.codec
}

/// Maps the decoder `D` to `C`, preserving the read buffer
/// wrapped by `Framed`.
pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
where
F: FnOnce(D) -> C
{
// This could be potentially simplified once rust-lang/rust#86555 hits stable
let FramedImpl { inner, state, codec } = self.inner;
FramedRead {
inner: FramedImpl {
inner,
state,
codec: map(codec)
}
}
}

/// Returns a mutable reference to the underlying decoder.
pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
self.project().inner.project().codec
Expand Down
17 changes: 17 additions & 0 deletions tokio-util/src/codec/framed_write.rs
Expand Up @@ -88,6 +88,23 @@ impl<T, E> FramedWrite<T, E> {
&mut self.inner.codec
}

/// Maps the encoder `E` to `C`, preserving the write buffer
/// wrapped by `Framed`.
pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
where
F: FnOnce(E) -> C
{
// This could be potentially simplified once rust-lang/rust#86555 hits stable
let FramedImpl { inner, state, codec } = self.inner;
FramedWrite {
inner: FramedImpl {
inner,
state,
codec: map(codec)
}
}
}

/// Returns a mutable reference to the underlying encoder.
pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
self.project().inner.project().codec
Expand Down
63 changes: 59 additions & 4 deletions tokio-util/tests/framed.rs
Expand Up @@ -12,7 +12,10 @@ use std::task::{Context, Poll};
const INITIAL_CAPACITY: usize = 8 * 1024;

/// Encode and decode u32 values.
struct U32Codec;
#[derive(Default)]
struct U32Codec {
read_bytes: usize
}

impl Decoder for U32Codec {
type Item = u32;
Expand All @@ -24,6 +27,7 @@ impl Decoder for U32Codec {
}

let n = buf.split_to(4).get_u32();
self.read_bytes += 4;
Ok(Some(n))
}
}
Expand All @@ -39,6 +43,38 @@ impl Encoder<u32> for U32Codec {
}
}

/// Encode and decode u64 values.
#[derive(Default)]
struct U64Codec {
read_bytes: usize
}

impl Decoder for U64Codec {
type Item = u64;
type Error = io::Error;

fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
if buf.len() < 8 {
return Ok(None);
}

let n = buf.split_to(8).get_u64();
self.read_bytes += 8;
Ok(Some(n))
}
}

impl Encoder<u64> for U64Codec {
type Error = io::Error;

fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
// Reserve space
dst.reserve(8);
dst.put_u64(item);
Ok(())
}
}

/// This value should never be used
struct DontReadIntoThis;

Expand All @@ -63,18 +99,37 @@ impl tokio::io::AsyncRead for DontReadIntoThis {

#[tokio::test]
async fn can_read_from_existing_buf() {
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec);
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);

let mut framed = Framed::from_parts(parts);
let num = assert_ok!(framed.next().await.unwrap());

assert_eq!(num, 42);
assert_eq!(framed.codec().read_bytes, 4);
}

#[tokio::test]
async fn can_read_from_existing_buf_after_codec_changed() {
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]);

let mut framed = Framed::from_parts(parts);
let num = assert_ok!(framed.next().await.unwrap());

assert_eq!(num, 42);
assert_eq!(framed.codec().read_bytes, 4);

let mut framed = framed.map_codec(|codec| U64Codec { read_bytes: codec.read_bytes });
let num = assert_ok!(framed.next().await.unwrap());

assert_eq!(num, 84);
assert_eq!(framed.codec().read_bytes, 12);
}

#[test]
fn external_buf_grows_to_init() {
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec);
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);

let framed = Framed::from_parts(parts);
Expand All @@ -85,7 +140,7 @@ fn external_buf_grows_to_init() {

#[test]
fn external_buf_does_not_shrink() {
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec);
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]);

let framed = Framed::from_parts(parts);
Expand Down
34 changes: 34 additions & 0 deletions tokio-util/tests/framed_read.rs
Expand Up @@ -50,6 +50,22 @@ impl Decoder for U32Decoder {
}
}

struct U64Decoder;

impl Decoder for U64Decoder {
type Item = u64;
type Error = io::Error;

fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
if buf.len() < 8 {
return Ok(None);
}

let n = buf.split_to(8).get_u64();
Ok(Some(n))
}
}

#[test]
fn read_multi_frame_in_packet() {
let mut task = task::spawn(());
Expand Down Expand Up @@ -84,6 +100,24 @@ fn read_multi_frame_across_packets() {
});
}

#[test]
fn read_multi_frame_in_packet_after_codec_changed() {
let mut task = task::spawn(());
let mock = mock! {
Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
};
let mut framed = FramedRead::new(mock, U32Decoder);

task.enter(|cx, _| {
assert_read!(pin!(framed).poll_next(cx), 0x04);

let mut framed = framed.map_decoder(|_| U64Decoder);
assert_read!(pin!(framed).poll_next(cx), 0x08);

assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
});
}

#[test]
fn read_not_ready() {
let mut task = task::spawn(());
Expand Down
39 changes: 39 additions & 0 deletions tokio-util/tests/framed_write.rs
Expand Up @@ -39,6 +39,19 @@ impl Encoder<u32> for U32Encoder {
}
}

struct U64Encoder;

impl Encoder<u64> for U64Encoder {
type Error = io::Error;

fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
// Reserve space
dst.reserve(8);
dst.put_u64(item);
Ok(())
}
}

#[test]
fn write_multi_frame_in_packet() {
let mut task = task::spawn(());
Expand All @@ -65,6 +78,32 @@ fn write_multi_frame_in_packet() {
});
}

#[test]
fn write_multi_frame_after_codec_changed() {
let mut task = task::spawn(());
let mock = mock! {
Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
};
let mut framed = FramedWrite::new(mock, U32Encoder);

task.enter(|cx, _| {
assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
assert!(pin!(framed).start_send(0x04).is_ok());

let mut framed = framed.map_encoder(|_| U64Encoder);
assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
assert!(pin!(framed).start_send(0x08).is_ok());

// Nothing written yet
assert_eq!(1, framed.get_ref().calls.len());

// Flush the writes
assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

assert_eq!(0, framed.get_ref().calls.len());
});
}

#[test]
fn write_hits_backpressure() {
const ITER: usize = 2 * 1024;
Expand Down