Skip to content

Commit

Permalink
Support body health detection
Browse files Browse the repository at this point in the history
  • Loading branch information
sfackler committed Mar 11, 2023
1 parent 9ed175d commit 1b2ca47
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 15 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Expand Up @@ -219,3 +219,6 @@ required-features = ["full"]
name = "server"
path = "tests/server.rs"
required-features = ["full"]

[patch.crates-io]
http-body = { git = "https://github.com/sfackler/http-body", branch = "body-poll-alive" }
20 changes: 17 additions & 3 deletions src/proto/h1/dispatch.rs
Expand Up @@ -28,7 +28,8 @@ pub(crate) trait Dispatch {
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>) -> crate::Result<()>;
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, IncomingBody)>)
-> crate::Result<()>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>;
fn should_poll(&self) -> bool;
}
Expand Down Expand Up @@ -249,7 +250,8 @@ where
let body = match body_len {
DecodedLength::ZERO => IncomingBody::empty(),
other => {
let (tx, rx) = IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
let (tx, rx) =
IncomingBody::new_channel(other, wants.contains(Wants::EXPECT));
self.body_tx = Some(tx);
rx
}
Expand Down Expand Up @@ -317,7 +319,19 @@ where
return Poll::Ready(Ok(()));
}
} else if !self.conn.can_buffer_body() {
ready!(self.poll_flush(cx))?;
if self.poll_flush(cx)?.is_pending() {
// If we're not able to make progress, check the body health
if let (Some(body), clear_body) =
OptGuard::new(self.body_rx.as_mut()).guard_mut()
{
body.poll_healthy(cx).map_err(|e| {
*clear_body = true;
crate::Error::new_user_body(e)
})?;
}

return Poll::Pending;
}
} else {
// A new scope is needed :(
if let (Some(mut body), clear_body) =
Expand Down
29 changes: 17 additions & 12 deletions src/proto/h2/mod.rs
Expand Up @@ -126,20 +126,29 @@ where

if me.body_tx.capacity() == 0 {
loop {
match ready!(me.body_tx.poll_capacity(cx)) {
Some(Ok(0)) => {}
Some(Ok(_)) => break,
Some(Err(e)) => {
match me.body_tx.poll_capacity(cx) {
Poll::Ready(Some(Ok(0))) => {}
Poll::Ready(Some(Ok(_))) => break,
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Err(crate::Error::new_body_write(e)))
}
None => {
Poll::Ready(None) => {
// None means the stream is no longer in a
// streaming state, we either finished it
// somehow, or the remote reset us.
return Poll::Ready(Err(crate::Error::new_body_write(
"send stream capacity unexpectedly closed",
)));
}
Poll::Pending => {
// If we're not able to make progress, check if the body is healthy
me.stream
.as_mut()
.poll_healthy(cx)
.map_err(|e| me.body_tx.on_user_err(e))?;

return Poll::Pending;
}
}
}
} else if let Poll::Ready(reason) = me
Expand All @@ -148,9 +157,7 @@ where
.map_err(crate::Error::new_body_write)?
{
debug!("stream received RST_STREAM: {:?}", reason);
return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(
reason,
))));
return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(reason))));
}

match ready!(me.stream.as_mut().poll_frame(cx)) {
Expand Down Expand Up @@ -365,14 +372,12 @@ where
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
if self.send_stream.write(&[], true).is_ok() {
return Poll::Ready(Ok(()))
return Poll::Ready(Ok(()));
}

Poll::Ready(Err(h2_to_io_error(
match ready!(self.send_stream.poll_reset(cx)) {
Ok(Reason::NO_ERROR) => {
return Poll::Ready(Ok(()))
}
Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())),
Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
}
Expand Down
145 changes: 145 additions & 0 deletions tests/server.rs
Expand Up @@ -1737,6 +1737,151 @@ async fn http_connect_new() {
assert_eq!(s(&vec), "bar=foo");
}

struct UnhealthyBody {
rx: oneshot::Receiver<()>,
tx: Option<oneshot::Sender<()>>,
}

impl Body for UnhealthyBody {
type Data = Bytes;

type Error = &'static str;

fn poll_frame(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
Poll::Ready(Some(Ok(http_body::Frame::data(Bytes::from_static(
&[0; 1024],
)))))
}

fn poll_healthy(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), Self::Error> {
if Pin::new(&mut self.rx).poll(cx).is_pending() {
return Ok(());
}

let _ = self.tx.take().unwrap().send(());
Err("blammo")
}
}

#[tokio::test]
async fn h1_unhealthy_body() {
let (listener, addr) = setup_tcp_listener();
let (unhealthy_tx, unhealthy_rx) = oneshot::channel();
let (read_body_tx, read_body_rx) = oneshot::channel();

let client = tokio::spawn(async move {
let mut tcp = connect_async(addr).await;
tcp.write_all(
b"\
GET / HTTP/1.1\r\n\
\r\n\
Host: localhost\r\n\
\r\n
",
)
.await
.expect("write 1");

let mut buf = [0; 1024];
loop {
let nread = tcp.read(&mut buf).await.expect("read 1");
if buf[..nread].contains(&0) {
break;
}
}

read_body_tx.send(()).unwrap();
unhealthy_rx.await.expect("rx");

while tcp.read(&mut buf).await.expect("read") > 0 {}
});

let mut read_body_rx = Some(read_body_rx);
let mut unhealthy_tx = Some(unhealthy_tx);
let svc = service_fn(move |_: Request<IncomingBody>| {
future::ok::<_, &'static str>(
Response::builder()
.status(200)
.body(UnhealthyBody {
rx: read_body_rx.take().unwrap(),
tx: unhealthy_tx.take(),
})
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
let err = http1::Builder::new()
.serve_connection(socket, svc)
.await
.err()
.unwrap();
assert!(err.to_string().contains("blammo"));

client.await.unwrap();
}

#[tokio::test]
async fn h2_unhealthy_body() {
let (listener, addr) = setup_tcp_listener();
let (unhealthy_tx, unhealthy_rx) = oneshot::channel();
let (read_body_tx, read_body_rx) = oneshot::channel();

let client = tokio::spawn(async move {
let tcp = connect_async(addr).await;
let (h2, connection) = h2::client::handshake(tcp).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();

let request = Request::get("/").body(()).unwrap();
let (response, _) = h2.send_request(request, true).unwrap();

let mut body = response.await.unwrap().into_body();

let bytes = body.data().await.unwrap().unwrap();
let _ = body.flow_control().release_capacity(bytes.len());

read_body_tx.send(()).unwrap();
unhealthy_rx.await.unwrap();

loop {
let bytes = match body.data().await.transpose() {
Ok(Some(bytes)) => bytes,
Ok(None) => panic!(),
Err(_) => break,
};
let _ = body.flow_control().release_capacity(bytes.len());
}
});

let mut read_body_rx = Some(read_body_rx);
let mut unhealthy_tx = Some(unhealthy_tx);
let svc = service_fn(move |_: Request<IncomingBody>| {
future::ok::<_, &'static str>(
Response::builder()
.status(200)
.body(UnhealthyBody {
rx: read_body_rx.take().unwrap(),
tx: unhealthy_tx.take(),
})
.unwrap(),
)
});

let (socket, _) = listener.accept().await.unwrap();
http2::Builder::new(TokioExecutor)
.serve_connection(socket, svc)
.await
.unwrap();

client.await.unwrap();
}

#[tokio::test]
async fn h2_connect() {
let (listener, addr) = setup_tcp_listener();
Expand Down

0 comments on commit 1b2ca47

Please sign in to comment.