diff --git a/tokio-util/src/codec/framed.rs b/tokio-util/src/codec/framed.rs index 516590081f4..aff577f22cd 100644 --- a/tokio-util/src/codec/framed.rs +++ b/tokio-util/src/codec/framed.rs @@ -106,6 +106,7 @@ where eof: false, is_readable: false, buffer: BytesMut::with_capacity(capacity), + has_errored: false, }, write: WriteFrame::default(), }, diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs index f27de028deb..f932414e1e0 100644 --- a/tokio-util/src/codec/framed_impl.rs +++ b/tokio-util/src/codec/framed_impl.rs @@ -27,10 +27,12 @@ pin_project! { const INITIAL_CAPACITY: usize = 8 * 1024; const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; +#[derive(Debug)] pub(crate) struct ReadFrame { pub(crate) eof: bool, pub(crate) is_readable: bool, pub(crate) buffer: BytesMut, + pub(crate) has_errored: bool, } pub(crate) struct WriteFrame { @@ -49,6 +51,7 @@ impl Default for ReadFrame { eof: false, is_readable: false, buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + has_errored: false, } } } @@ -72,6 +75,7 @@ impl From for ReadFrame { buffer, is_readable: size > 0, eof: false, + has_errored: false, } } } @@ -126,30 +130,42 @@ where // // The initial state is `reading`. // - // | state | eof | is_readable | - // |---------|-------|-------------| - // | reading | false | false | - // | framing | false | true | - // | pausing | true | true | - // | paused | true | false | - // - // `decode_eof` - // returns `Some` read 0 bytes - // │ │ │ │ - // │ ▼ │ ▼ - // ┌───────┐ `decode_eof` ┌──────┐ - // ┌──read 0 bytes──▶│pausing│─returns `None`─▶│paused│──┐ - // │ └───────┘ └──────┘ │ - // pending read┐ │ ┌──────┐ │ ▲ │ - // │ │ │ │ │ │ │ │ - // │ ▼ │ │ `decode` returns `Some`│ pending read - // │ ╔═══════╗ ┌───────┐◀─┘ │ - // └──║reading║─read n>0 bytes─▶│framing│ │ - // ╚═══════╝ └───────┘◀──────read n>0 bytes┘ - // ▲ │ - // │ │ - // └─`decode` returns `None`─┘ + // | state | eof | is_readable | has_errored | + // |---------|-------|-------------|-------------| + // | reading | false | false | false | + // | framing | false | true | false | + // | pausing | true | true | false | + // | paused | true | false | false | + // | errored | | | true | + // `decode_eof` returns Err + // ┌────────────────────────────────────────────────────────┐ + // `decode_eof` returns │ │ + // `Ok(Some)` │ │ + // ┌─────┐ │ `decode_eof` returns After returning │ + // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ + // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ + // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ + // Pending read │ │ │ │ │ │ + // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ + // │ │ │ ┌──────┐ │ Pending │ │ + // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ + // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ + // └──┬─▲────┘ └─────┬──┬┘ │ │ + // │ │ │ │ `decode` returns Err │ │ + // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ + // │ read returns Err │ + // └────────────────────────────────────────────────────────────────────────────────────────────┘ loop { + // Return `None` if we have encountered an error from the underlying decoder + // See: https://github.com/tokio-rs/tokio/issues/3976 + if state.has_errored { + // preparing has_errored -> paused + trace!("Returning None and setting paused"); + state.is_readable = false; + state.has_errored = false; + return Poll::Ready(None); + } + // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", // i.e. it _might_ contain data consumable as a frame or closing frame. // Both signal that there is no such data by returning `None`. @@ -165,7 +181,11 @@ where // pausing or framing if state.eof { // pausing - let frame = pinned.codec.decode_eof(&mut state.buffer)?; + let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + })?; if frame.is_none() { state.is_readable = false; // prepare pausing -> paused } @@ -176,7 +196,11 @@ where // framing trace!("attempting to decode a frame"); - if let Some(frame) = pinned.codec.decode(&mut state.buffer)? { + if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + op + })? { trace!("frame decoded from buffer"); // implicit framing -> framing return Poll::Ready(Some(Ok(frame))); @@ -190,7 +214,13 @@ where // Make sure we've got room for at least one byte to read to ensure // that we don't get a spurious 0 that looks like EOF. state.buffer.reserve(1); - let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer)? { + let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( + |err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + }, + )? { Poll::Ready(ct) => ct, // implicit reading -> reading or implicit paused -> paused Poll::Pending => return Poll::Pending, diff --git a/tokio-util/src/codec/framed_read.rs b/tokio-util/src/codec/framed_read.rs index 7347470c409..502a073d0f7 100644 --- a/tokio-util/src/codec/framed_read.rs +++ b/tokio-util/src/codec/framed_read.rs @@ -51,6 +51,7 @@ where eof: false, is_readable: false, buffer: BytesMut::with_capacity(capacity), + has_errored: false, }, }, } diff --git a/tokio-util/tests/framed_stream.rs b/tokio-util/tests/framed_stream.rs new file mode 100644 index 00000000000..76d8af7b7d6 --- /dev/null +++ b/tokio-util/tests/framed_stream.rs @@ -0,0 +1,38 @@ +use futures_core::stream::Stream; +use std::{io, pin::Pin}; +use tokio_test::{assert_ready, io::Builder, task}; +use tokio_util::codec::{BytesCodec, FramedRead}; + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +#[tokio::test] +async fn return_none_after_error() { + let mut io = FramedRead::new( + Builder::new() + .read(b"abcdef") + .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out")) + .read(b"more data") + .build(), + BytesCodec::new(), + ); + + let mut task = task::spawn(()); + + task.enter(|cx, _| { + assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec()); + assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err()); + assert!(assert_ready!(pin!(io).poll_next(cx)).is_none()); + assert_read!(pin!(io).poll_next(cx), b"more data".to_vec()); + }) +}