Skip to content

Commit

Permalink
codec: add borrow framed
Browse files Browse the repository at this point in the history
  • Loading branch information
suikammd committed Jun 28, 2023
1 parent ce23db6 commit df2b2e5
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 0 deletions.
78 changes: 78 additions & 0 deletions tokio-util/src/codec/framed.rs
Expand Up @@ -30,6 +30,22 @@ pin_project! {
}
}

pin_project! {
/// A borrowed unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// You can create a `BorrowFramed` instance by using the `with_codec` function of Framed
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`Decoder::framed`]: crate::codec::Decoder::framed()
pub struct BorrowFramed<'borrow, T, U> {
#[pin]
inner: FramedImpl<&'borrow mut T, U, &'borrow mut RWFrames>,
}
}

impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite,
Expand Down Expand Up @@ -224,6 +240,29 @@ impl<T, U> Framed<T, U> {
})
}

/// Maps the codec `U` to `C` temporarily using &mut self
/// preserving the read and write buffers wrapped by `Framed`.
///
/// Note that care should be taken to not tamper with the underlying codec
/// as it may corrupt the stream of frames otherwise being worked with.
pub fn with_codec<C, F>(&mut self, map: F) -> BorrowFramed<'_, T, C>
where
F: FnOnce(&mut U) -> C,
{
let FramedImpl {
inner,
state,
codec,
} = &mut self.inner;
BorrowFramed {
inner: FramedImpl {
inner,
state,
codec: map(codec),
},
}
}

/// Returns a mutable reference to the underlying codec wrapped by
/// `Framed`.
///
Expand Down Expand Up @@ -341,6 +380,45 @@ where
}
}

// This impl just defers to the underlying FramedImpl
impl<'borrow, T, U> Stream for BorrowFramed<'borrow, T, U>
where
T: AsyncRead + Unpin,
U: Decoder,
{
type Item = Result<U::Item, U::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
}

// This impl just defers to the underlying FramedImpl
impl<'borrow, T, I, U> Sink<I> for BorrowFramed<'borrow, T, U>
where
T: AsyncWrite + Unpin,
U: Encoder<I>,
U::Error: From<io::Error>,
{
type Error = U::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}

fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}

/// `FramedParts` contains an export of the data of a Framed transport.
/// It can be used to construct a new [`Framed`] with a different codec.
/// It contains all current buffers and the inner transport.
Expand Down
20 changes: 20 additions & 0 deletions tokio-util/src/codec/framed_impl.rs
Expand Up @@ -115,6 +115,26 @@ impl BorrowMut<WriteFrame> for RWFrames {
&mut self.write
}
}
impl Borrow<ReadFrame> for &mut RWFrames {
fn borrow(&self) -> &ReadFrame {
&self.read
}
}
impl BorrowMut<ReadFrame> for &mut RWFrames {
fn borrow_mut(&mut self) -> &mut ReadFrame {
&mut self.read
}
}
impl Borrow<WriteFrame> for &mut RWFrames {
fn borrow(&self) -> &WriteFrame {
&self.write
}
}
impl BorrowMut<WriteFrame> for &mut RWFrames {
fn borrow_mut(&mut self) -> &mut WriteFrame {
&mut self.write
}
}
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
T: AsyncRead,
Expand Down
69 changes: 69 additions & 0 deletions tokio-util/src/codec/framed_read.rs
Expand Up @@ -22,6 +22,17 @@ pin_project! {
}
}

pin_project! {
/// A [`Stream`] of messages decoded from an [`AsyncRead`].
///
/// [`Stream`]: futures_core::Stream
/// [`AsyncRead`]: tokio::io::AsyncRead
pub struct BorrowFramedRead<'borrow, T, D> {
#[pin]
inner: FramedImpl<&'borrow mut T, D, &'borrow mut ReadFrame>,
}
}

// ===== impl FramedRead =====

impl<T, D> FramedRead<T, D>
Expand Down Expand Up @@ -129,6 +140,27 @@ impl<T, D> FramedRead<T, D> {
}
}

/// Maps the decoder `D` to `C` temporarily using &mut self,
/// preserving the read buffer wrapped by `Framed`.
pub fn with_decoder<C, F>(&mut self, map: F) -> BorrowFramedRead<'_, T, C>
where
F: FnOnce(&mut D) -> C,
{
// This could be potentially simplified once rust-lang/rust#86555 hits stable
let FramedImpl {
inner,
state,
codec,
} = &mut self.inner;
BorrowFramedRead {
inner: FramedImpl {
inner,
state,
codec: map(codec),
},
}
}

/// Returns a mutable reference to the underlying decoder.
pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
self.project().inner.project().codec
Expand Down Expand Up @@ -197,3 +229,40 @@ where
.finish()
}
}

// This impl just defers to the underlying FramedImpl
impl<'borrow, T, D> Stream for BorrowFramedRead<'borrow, T, D>
where
T: AsyncRead + Unpin,
D: Decoder,
{
type Item = Result<D::Item, D::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
}

// This impl just defers to the underlying T: Sink
impl<'borrow, T, I, D> Sink<I> for BorrowFramedRead<'borrow, T, D>
where
T: Sink<I> + Unpin,
{
type Error = T::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.project().inner.poll_ready(cx)
}

fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.project().inner.project().inner.start_send(item)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.project().inner.poll_flush(cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.project().inner.poll_close(cx)
}
}
69 changes: 69 additions & 0 deletions tokio-util/src/codec/framed_write.rs
Expand Up @@ -22,6 +22,16 @@ pin_project! {
}
}

pin_project! {
/// A [`Sink`] of frames encoded to an `AsyncWrite`.
///
/// [`Sink`]: futures_sink::Sink
pub struct BorrowFramedWrite<'borrow, T, E> {
#[pin]
inner: FramedImpl<&'borrow mut T, E, &'borrow mut WriteFrame>,
}
}

impl<T, E> FramedWrite<T, E>
where
T: AsyncWrite,
Expand Down Expand Up @@ -109,6 +119,27 @@ impl<T, E> FramedWrite<T, E> {
}
}

/// Maps the encoder `E` to `C` temporarily using &mut self,
/// preserving the write buffer wrapped by `Framed`.
pub fn with_encoder<C, F>(&mut self, map: F) -> BorrowFramedWrite<'_, T, C>
where
F: FnOnce(&mut E) -> C,
{
// This could be potentially simplified once rust-lang/rust#86555 hits stable
let FramedImpl {
inner,
state,
codec,
} = &mut self.inner;
BorrowFramedWrite {
inner: FramedImpl {
inner,
state,
codec: map(codec),
},
}
}

/// Returns a mutable reference to the underlying encoder.
pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
self.project().inner.project().codec
Expand Down Expand Up @@ -186,3 +217,41 @@ where
.finish()
}
}

// This impl just defers to the underlying FramedImpl
impl<'borrow, T, I, E> Sink<I> for BorrowFramedWrite<'borrow, T, E>
where
T: AsyncWrite + Unpin,
E: Encoder<I>,
E::Error: From<io::Error>,
{
type Error = E::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}

fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}

// This impl just defers to the underlying T: Stream
impl<'borrow, T, D> Stream for BorrowFramedWrite<'borrow, T, D>
where
T: Stream + Unpin,
{
type Item = T::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.project().inner.poll_next(cx)
}
}
27 changes: 27 additions & 0 deletions tokio-util/tests/framed.rs
Expand Up @@ -150,3 +150,30 @@ fn external_buf_does_not_shrink() {

assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2);
}

#[tokio::test]
async fn borrow_framed() {
let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
parts.read_buf = BytesMut::from(
&[
0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84, 0, 0, 0, 0, 0, 0, 0, 84, 0, 0, 0, 42,
][..],
);

let mut framed = Framed::from_parts(parts);
let num = assert_ok!(framed.next().await.unwrap());

assert_eq!(num, 42);
assert_eq!(framed.codec().read_bytes, 4);

let mut borrow_framed = framed.with_codec(|codec| U64Codec {
read_bytes: codec.read_bytes,
});
assert_eq!(assert_ok!(borrow_framed.next().await.unwrap()), 84);
assert_eq!(assert_ok!(borrow_framed.next().await.unwrap()), 84);

let num = assert_ok!(framed.next().await.unwrap());

assert_eq!(num, 42);
assert_eq!(framed.codec().read_bytes, 8);
}
19 changes: 19 additions & 0 deletions tokio-util/tests/framed_read.rs
Expand Up @@ -118,6 +118,25 @@ fn read_multi_frame_in_packet_after_codec_changed() {
});
}

#[test]
fn borrow_framed_read() {
let mut task = task::spawn(());
let mock = mock! {
Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x04".to_vec()),
};
let mut framed = FramedRead::new(mock, U32Decoder);

task.enter(|cx, _| {
assert_read!(pin!(framed).poll_next(cx), 0x04);

let mut borrow_framed = framed.with_decoder(|_| U64Decoder);
assert_read!(pin!(borrow_framed).poll_next(cx), 0x08);

assert_read!(pin!(framed).poll_next(cx), 0x04);
assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
});
}

#[test]
fn read_not_ready() {
let mut task = task::spawn(());
Expand Down
26 changes: 26 additions & 0 deletions tokio-util/tests/framed_write.rs
Expand Up @@ -104,6 +104,32 @@ fn write_multi_frame_after_codec_changed() {
});
}

#[test]
fn borrow_framed_write() {
let mut task = task::spawn(());
let mock = mock! {
Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
};
let mut framed = FramedWrite::new(mock, U32Encoder);

task.enter(|cx, _| {
assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
assert!(pin!(framed).start_send(0x04).is_ok());

let mut borrow_framed = framed.with_encoder(|_| U64Encoder);
assert!(assert_ready!(pin!(borrow_framed).poll_ready(cx)).is_ok());
assert!(pin!(borrow_framed).start_send(0x08).is_ok());

// Nothing written yet
assert_eq!(1, framed.get_ref().calls.len());

// Flush the writes
assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

assert_eq!(0, framed.get_ref().calls.len());
});
}

#[test]
fn write_hits_backpressure() {
const ITER: usize = 2 * 1024;
Expand Down

0 comments on commit df2b2e5

Please sign in to comment.