Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(server): Fix race condition in client dispatcher (#2419) #3041

Merged
merged 1 commit into from Nov 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
224 changes: 143 additions & 81 deletions src/proto/h2/client.rs
Expand Up @@ -7,19 +7,22 @@ use futures_channel::{mpsc, oneshot};
use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _};
use futures_util::stream::StreamExt as _;
use h2::client::{Builder, SendRequest};
use h2::SendStream;
use http::{Method, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, trace, warn};

use super::{ping, H2Upgraded, PipeToSendStream, SendBuf};
use crate::body::HttpBody;
use crate::client::dispatch::Callback;
use crate::common::{exec::Exec, task, Future, Never, Pin, Poll};
use crate::ext::Protocol;
use crate::headers;
use crate::proto::h2::UpgradedSendStream;
use crate::proto::Dispatched;
use crate::upgrade::Upgraded;
use crate::{Body, Request, Response};
use h2::client::ResponseFuture;

type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>;

Expand Down Expand Up @@ -170,6 +173,7 @@ where
executor: exec,
h2_tx,
req_rx,
fut_ctx: None,
})
}

Expand All @@ -193,6 +197,20 @@ where
}
}

struct FutCtx<B>
where
B: HttpBody,
{
is_connect: bool,
eos: bool,
fut: ResponseFuture,
body_tx: SendStream<SendBuf<B::Data>>,
body: B,
cb: Callback<Request<B>, Response<Body>>,
}

impl<B: HttpBody> Unpin for FutCtx<B> {}

pub(crate) struct ClientTask<B>
where
B: HttpBody,
Expand All @@ -203,6 +221,7 @@ where
executor: Exec,
h2_tx: SendRequest<SendBuf<B::Data>>,
req_rx: ClientRx<B>,
fut_ctx: Option<FutCtx<B>>,
}

impl<B> ClientTask<B>
Expand All @@ -214,6 +233,99 @@ where
}
}

impl<B> ClientTask<B>
where
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
fn poll_pipe(&mut self, f: FutCtx<B>, cx: &mut task::Context<'_>) {
let ping = self.ping.clone();
let send_stream = if !f.is_connect {
if !f.eos {
let mut pipe = Box::pin(PipeToSendStream::new(f.body, f.body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});

// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
});
// Clear send task
self.executor.execute(pipe);
}
}
}

None
} else {
Some(f.body_tx)
};

let fut = f.fut.map(move |result| match result {
Ok(res) => {
// record that we got the response headers
ping.record_non_data();

let content_length = headers::content_length_parse_all(res.headers());
if let (Some(mut send_stream), StatusCode::OK) = (send_stream, res.status()) {
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect response with non-zero body not supported");

send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
return Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
));
}
let (parts, recv_stream) = res.into_parts();
let mut res = Response::from_parts(parts, Body::empty());

let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
recv_stream,
buf: Bytes::new(),
};
let upgraded = Upgraded::new(io, Bytes::new());

pending.fulfill(upgraded);
res.extensions_mut().insert(on_upgrade);

Ok(res)
} else {
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
crate::Body::h2(stream, content_length.into(), ping)
});
Ok(res)
}
}
Err(err) => {
ping.ensure_not_timed_out().map_err(|e| (e, None))?;

debug!("client response error: {}", err);
Err((crate::Error::new_h2(err), None))
}
});
self.executor.execute(f.cb.send_when(fut));
}
}

impl<B> Future for ClientTask<B>
where
B: HttpBody + Send + 'static,
Expand All @@ -237,6 +349,16 @@ where
}
};

match self.fut_ctx.take() {
// If we were waiting on pending open
// continue where we left off.
Some(f) => {
self.poll_pipe(f, cx);
continue;
}
None => (),
}

match self.req_rx.poll_recv(cx) {
Poll::Ready(Some((req, cb))) => {
// check that future hasn't been canceled already
Expand All @@ -255,7 +377,6 @@ where

let is_connect = req.method() == Method::CONNECT;
let eos = body.is_end_stream();
let ping = self.ping.clone();

if is_connect {
if headers::content_length_parse_all(req.headers())
Expand Down Expand Up @@ -283,90 +404,31 @@ where
}
};

let send_stream = if !is_connect {
if !eos {
let mut pipe =
Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});

// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
});
self.executor.execute(pipe);
}
}
}

None
} else {
Some(body_tx)
let f = FutCtx {
is_connect,
eos,
fut,
body_tx,
body,
cb,
};

let fut = fut.map(move |result| match result {
Ok(res) => {
// record that we got the response headers
ping.record_non_data();

let content_length = headers::content_length_parse_all(res.headers());
if let (Some(mut send_stream), StatusCode::OK) =
(send_stream, res.status())
{
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect response with non-zero body not supported");

send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
return Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
));
}
let (parts, recv_stream) = res.into_parts();
let mut res = Response::from_parts(parts, Body::empty());

let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
recv_stream,
buf: Bytes::new(),
};
let upgraded = Upgraded::new(io, Bytes::new());

pending.fulfill(upgraded);
res.extensions_mut().insert(on_upgrade);

Ok(res)
} else {
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
crate::Body::h2(stream, content_length.into(), ping)
});
Ok(res)
}
// Check poll_ready() again.
// If the call to send_request() resulted in the new stream being pending open
// we have to wait for the open to complete before accepting new requests.
match self.h2_tx.poll_ready(cx) {
Poll::Pending => {
// Save Context
self.fut_ctx = Some(f);
return Poll::Pending;
}
Err(err) => {
ping.ensure_not_timed_out().map_err(|e| (e, None))?;

debug!("client response error: {}", err);
Err((crate::Error::new_h2(err), None))
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => {
f.cb.send(Err((crate::Error::new_h2(err), None)));
continue;
}
});
self.executor.execute(cb.send_when(fut));
}
self.poll_pipe(f, cx);
continue;
}

Expand Down