diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index a1a3675ef..85134ceae 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -2,7 +2,7 @@ use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings use crate::{ body::BoxBody, client::GrpcService, - codec::{encode_client, Codec, Streaming}, + codec::{encode_client, Codec, Decoder, Streaming}, request::SanitizeHeaders, Code, Request, Response, Status, }; @@ -30,6 +30,10 @@ use std::fmt; /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, + config: GrpcConfig, +} + +struct GrpcConfig { origin: Uri, /// Which compression encodings does the client accept? accept_compression_encodings: EnabledCompressionEncodings, @@ -40,12 +44,7 @@ pub struct Grpc { impl Grpc { /// Creates a new gRPC client with the provided [`GrpcService`]. pub fn new(inner: T) -> Self { - Self { - inner, - origin: Uri::default(), - send_compression_encodings: None, - accept_compression_encodings: EnabledCompressionEncodings::default(), - } + Self::with_origin(inner, Uri::default()) } /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`. @@ -55,9 +54,11 @@ impl Grpc { pub fn with_origin(inner: T, origin: Uri) -> Self { Self { inner, - origin, - send_compression_encodings: None, - accept_compression_encodings: EnabledCompressionEncodings::default(), + config: GrpcConfig { + origin, + send_compression_encodings: None, + accept_compression_encodings: EnabledCompressionEncodings::default(), + }, } } @@ -88,7 +89,7 @@ impl Grpc { /// # }; /// ``` pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.send_compression_encodings = Some(encoding); + self.config.send_compression_encodings = Some(encoding); self } @@ -119,7 +120,7 @@ impl Grpc { /// # }; /// ``` pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.accept_compression_encodings.enable(encoding); + self.config.accept_compression_encodings.enable(encoding); self } @@ -226,6 +227,73 @@ impl Grpc { M1: Send + Sync + 'static, M2: Send + Sync + 'static, { + let request = request + .map(|s| encode_client(codec.encoder(), s, self.config.send_compression_encodings)) + .map(BoxBody::new); + + let request = self.config.prepare_request(request, path); + + let response = self + .inner + .call(request) + .await + .map_err(Status::from_error_generic)?; + + let decoder = codec.decoder(); + + self.create_response(decoder, response) + } + + // Keeping this code in a separate function from Self::streaming lets functions that return the + // same output share the generated binary code + fn create_response( + &self, + decoder: impl Decoder + Send + 'static, + response: http::Response, + ) -> Result>, Status> + where + T: GrpcService, + T::ResponseBody: Body + Send + 'static, + ::Error: Into, + { + let encoding = CompressionEncoding::from_encoding_header( + response.headers(), + self.config.accept_compression_encodings, + )?; + + let status_code = response.status(); + let trailers_only_status = Status::from_header_map(response.headers()); + + // We do not need to check for trailers if the `grpc-status` header is present + // with a valid code. + let expect_additional_trailers = if let Some(status) = trailers_only_status { + if status.code() != Code::Ok { + return Err(status); + } + + false + } else { + true + }; + + let response = response.map(|body| { + if expect_additional_trailers { + Streaming::new_response(decoder, body, status_code, encoding) + } else { + Streaming::new_empty(decoder, body) + } + }); + + Ok(Response::from_http(response)) + } +} + +impl GrpcConfig { + fn prepare_request( + &self, + request: Request>, + path: PathAndQuery, + ) -> http::Request> { let scheme = self.origin.scheme().cloned(); let authority = self.origin.authority().cloned(); @@ -236,10 +304,6 @@ impl Grpc { let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); - let request = request - .map(|s| encode_client(codec.encoder(), s, self.send_compression_encodings)) - .map(BoxBody::new); - let mut request = request.into_http( uri, http::Method::POST, @@ -274,41 +338,7 @@ impl Grpc { ); } - let response = self - .inner - .call(request) - .await - .map_err(|err| Status::from_error(err.into()))?; - - let encoding = CompressionEncoding::from_encoding_header( - response.headers(), - self.accept_compression_encodings, - )?; - - let status_code = response.status(); - let trailers_only_status = Status::from_header_map(response.headers()); - - // We do not need to check for trailers if the `grpc-status` header is present - // with a valid code. - let expect_additional_trailers = if let Some(status) = trailers_only_status { - if status.code() != Code::Ok { - return Err(status); - } - - false - } else { - true - }; - - let response = response.map(|body| { - if expect_additional_trailers { - Streaming::new_response(codec.decoder(), body, status_code, encoding) - } else { - Streaming::new_empty(codec.decoder(), body) - } - }); - - Ok(Response::from_http(response)) + request } } @@ -316,9 +346,11 @@ impl Clone for Grpc { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - origin: self.origin.clone(), - send_compression_encodings: self.send_compression_encodings, - accept_compression_encodings: self.accept_compression_encodings, + config: GrpcConfig { + origin: self.config.origin.clone(), + send_compression_encodings: self.config.send_compression_encodings, + accept_compression_encodings: self.config.accept_compression_encodings, + }, } } } @@ -329,13 +361,16 @@ impl fmt::Debug for Grpc { f.field("inner", &self.inner); - f.field("origin", &self.origin); + f.field("origin", &self.config.origin); - f.field("compression_encoding", &self.send_compression_encodings); + f.field( + "compression_encoding", + &self.config.send_compression_encodings, + ); f.field( "accept_compression_encodings", - &self.accept_compression_encodings, + &self.config.accept_compression_encodings, ); f.finish() diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 03614b1ac..bb422652f 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -21,6 +21,10 @@ const BUFFER_SIZE: usize = 8 * 1024; /// to fetch the message stream and trailing metadata pub struct Streaming { decoder: Box + Send + 'static>, + inner: StreamingInner, +} + +struct StreamingInner { body: BoxBody, state: State, direction: Direction, @@ -96,20 +100,157 @@ impl Streaming { { Self { decoder: Box::new(decoder), - body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) - .map_err(|err| Status::map_error(err.into())) - .boxed_unsync(), - state: State::ReadHeader, - direction, - buf: BytesMut::with_capacity(BUFFER_SIZE), - trailers: None, - decompress_buf: BytesMut::new(), - encoding, + inner: StreamingInner { + body: body + .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_err(|err| Status::map_error(err.into())) + .boxed_unsync(), + state: State::ReadHeader, + direction, + buf: BytesMut::with_capacity(BUFFER_SIZE), + trailers: None, + decompress_buf: BytesMut::new(), + encoding, + }, } } } +impl StreamingInner { + fn decode_chunk(&mut self) -> Result>, Status> { + if let State::ReadHeader = self.state { + if self.buf.remaining() < HEADER_SIZE { + return Ok(None); + } + + let compression_encoding = match self.buf.get_u8() { + 0 => None, + 1 => { + { + if self.encoding.is_some() { + self.encoding + } else { + // https://grpc.github.io/grpc/core/md_doc_compression.html + // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding + // entry different from identity in its metadata MUST fail with INTERNAL status, + // its associated description indicating the invalid Compressed-Flag condition. + return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified")); + } + } + } + f => { + trace!("unexpected compression flag"); + let message = if let Direction::Response(status) = self.direction { + format!( + "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", + f, status + ) + } else { + format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f) + }; + return Err(Status::new(Code::Internal, message)); + } + }; + let len = self.buf.get_u32() as usize; + self.buf.reserve(len); + + self.state = State::ReadBody { + compression: compression_encoding, + len, + } + } + + if let State::ReadBody { len, compression } = self.state { + // if we haven't read enough of the message then return and keep + // reading + if self.buf.remaining() < len || self.buf.len() < len { + return Ok(None); + } + + let decode_buf = if let Some(encoding) = compression { + self.decompress_buf.clear(); + + if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) + { + let message = if let Direction::Response(status) = self.direction { + format!( + "Error decompressing: {}, while receiving response with status: {}", + err, status + ) + } else { + format!("Error decompressing: {}, while sending request", err) + }; + return Err(Status::new(Code::Internal, message)); + } + let decompressed_len = self.decompress_buf.len(); + DecodeBuf::new(&mut self.decompress_buf, decompressed_len) + } else { + DecodeBuf::new(&mut self.buf, len) + }; + + return Ok(Some(decode_buf)); + } + + Ok(None) + } + + // Returns Some(()) if data was found or None if the loop in `poll_next` should break + fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { + let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + Some(Ok(d)) => Some(d), + Some(Err(e)) => { + let _ = std::mem::replace(&mut self.state, State::Error); + let err: crate::Error = e.into(); + debug!("decoder inner stream error: {:?}", err); + let status = Status::from_error(err); + return Poll::Ready(Err(status)); + } + None => None, + }; + + Poll::Ready(if let Some(data) = chunk { + self.buf.put(data); + Ok(Some(())) + } else { + // FIXME: improve buf usage. + if self.buf.has_remaining() { + trace!("unexpected EOF decoding stream"); + Err(Status::new( + Code::Internal, + "Unexpected EOF decoding stream.".to_string(), + )) + } else { + Ok(None) + } + }) + } + + fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Direction::Response(status) = self.direction { + match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { + Ok(trailer) => { + if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { + if let Some(e) = e { + return Poll::Ready(Err(e)); + } else { + return Poll::Ready(Ok(())); + } + } else { + self.trailers = trailer.map(MetadataMap::from_headers); + } + } + Err(e) => { + let err: crate::Error = e.into(); + debug!("decoder inner trailers error: {:?}", err); + let status = Status::from_error(err); + return Poll::Ready(Err(status)); + } + } + } + Poll::Ready(Ok(())) + } +} + impl Streaming { /// Fetch the next message from this stream. /// @@ -165,7 +306,7 @@ impl Streaming { pub async fn trailers(&mut self) -> Result, Status> { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status - if let Some(trailers) = self.trailers.take() { + if let Some(trailers) = self.inner.trailers.take() { return Ok(Some(trailers)); } @@ -174,13 +315,13 @@ impl Streaming { // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. - if let Some(trailers) = self.trailers.take() { + if let Some(trailers) = self.inner.trailers.take() { return Ok(Some(trailers)); } // Trailers were not caught during poll_next and thus lets poll for // them manually. - let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx)) + let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) .await .map_err(|e| Status::from_error(Box::new(e))); @@ -188,90 +329,16 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - if let State::ReadHeader = self.state { - if self.buf.remaining() < HEADER_SIZE { - return Ok(None); - } - - let compression_encoding = match self.buf.get_u8() { - 0 => None, - 1 => { - { - if self.encoding.is_some() { - self.encoding - } else { - // https://grpc.github.io/grpc/core/md_doc_compression.html - // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding - // entry different from identity in its metadata MUST fail with INTERNAL status, - // its associated description indicating the invalid Compressed-Flag condition. - return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified")); - } - } - } - f => { - trace!("unexpected compression flag"); - let message = if let Direction::Response(status) = self.direction { - format!( - "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", - f, status - ) - } else { - format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f) - }; - return Err(Status::new(Code::Internal, message)); - } - }; - let len = self.buf.get_u32() as usize; - self.buf.reserve(len); - - self.state = State::ReadBody { - compression: compression_encoding, - len, - } - } - - if let State::ReadBody { len, compression } = self.state { - // if we haven't read enough of the message then return and keep - // reading - if self.buf.remaining() < len || self.buf.len() < len { - return Ok(None); - } - - let decoding_result = if let Some(encoding) = compression { - self.decompress_buf.clear(); - - if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) - { - let message = if let Direction::Response(status) = self.direction { - format!( - "Error decompressing: {}, while receiving response with status: {}", - err, status - ) - } else { - format!("Error decompressing: {}, while sending request", err) - }; - return Err(Status::new(Code::Internal, message)); - } - let decompressed_len = self.decompress_buf.len(); - self.decoder.decode(&mut DecodeBuf::new( - &mut self.decompress_buf, - decompressed_len, - )) - } else { - self.decoder.decode(&mut DecodeBuf::new(&mut self.buf, len)) - }; - - return match decoding_result { - Ok(Some(msg)) => { - self.state = State::ReadHeader; + match self.inner.decode_chunk()? { + Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { + Some(msg) => { + self.inner.state = State::ReadHeader; Ok(Some(msg)) } - Ok(None) => Ok(None), - Err(e) => Err(e), - }; + None => Ok(None), + }, + None => Ok(None), } - - Ok(None) } } @@ -280,7 +347,7 @@ impl Stream for Streaming { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - if let State::Error = &self.state { + if let State::Error = &self.inner.state { return Poll::Ready(None); } @@ -291,57 +358,16 @@ impl Stream for Streaming { return Poll::Ready(Some(Ok(item))); } - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { - Some(Ok(d)) => Some(d), - Some(Err(e)) => { - let _ = std::mem::replace(&mut self.state, State::Error); - let err: crate::Error = e.into(); - debug!("decoder inner stream error: {:?}", err); - let status = Status::from_error(err); - return Poll::Ready(Some(Err(status))); - } - None => None, - }; - - if let Some(data) = chunk { - self.buf.put(data); - } else { - // FIXME: improve buf usage. - if self.buf.has_remaining() { - trace!("unexpected EOF decoding stream"); - return Poll::Ready(Some(Err(Status::new( - Code::Internal, - "Unexpected EOF decoding stream.".to_string(), - )))); - } else { - break; - } - } - } - - if let Direction::Response(status) = self.direction { - match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { - Ok(trailer) => { - if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - if let Some(e) = e { - return Some(Err(e)).into(); - } else { - return Poll::Ready(None); - } - } else { - self.trailers = trailer.map(MetadataMap::from_headers); - } - } - Err(e) => { - let err: crate::Error = e.into(); - debug!("decoder inner trailers error: {:?}", err); - let status = Status::from_error(err); - return Some(Err(status)).into(); - } + match ready!(self.inner.poll_data(cx))? { + Some(()) => (), + None => break, } } - Poll::Ready(None) + Poll::Ready(match ready!(self.inner.poll_response(cx)) { + Ok(()) => None, + Err(err) => Some(Err(err)), + }) } } diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index fbeb1e870..60a8548ce 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -58,66 +58,80 @@ where T: Encoder, U: Stream>, { - async_stream::stream! { - let mut buf = BytesMut::with_capacity(BUFFER_SIZE); - - let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable { - None - } else { - compression_encoding - }; - - let mut uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) - } else { - BytesMut::new() - }; - - futures_util::pin_mut!(source); - - loop { - match source.next().await { - Some(Ok(item)) => { - buf.reserve(HEADER_SIZE); - unsafe { - buf.advance_mut(HEADER_SIZE); - } - - if let Some(encoding) = compression_encoding { - uncompression_buf.clear(); - - encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - - let uncompressed_len = uncompression_buf.len(); - - compress( - encoding, - &mut uncompression_buf, - &mut buf, - uncompressed_len, - ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; - } else { - encoder.encode(item, &mut EncodeBuf::new(&mut buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - } - - // now that we know length, we can write the header - let len = buf.len() - HEADER_SIZE; - assert!(len <= std::u32::MAX as usize); - { - let mut buf = &mut buf[..HEADER_SIZE]; - buf.put_u8(compression_encoding.is_some() as u8); - buf.put_u32(len as u32); - } - - yield Ok(buf.split_to(len + HEADER_SIZE).freeze()); - }, - Some(Err(status)) => yield Err(status), - None => break, - } - } + let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + + let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable + { + None + } else { + compression_encoding + }; + + let mut uncompression_buf = if compression_encoding.is_some() { + BytesMut::with_capacity(BUFFER_SIZE) + } else { + BytesMut::new() + }; + + source.map(move |result| { + let item = result?; + + encode_item( + &mut encoder, + &mut buf, + &mut uncompression_buf, + compression_encoding, + item, + ) + }) +} + +fn encode_item( + encoder: &mut T, + buf: &mut BytesMut, + uncompression_buf: &mut BytesMut, + compression_encoding: Option, + item: T::Item, +) -> Result +where + T: Encoder, +{ + buf.reserve(HEADER_SIZE); + unsafe { + buf.advance_mut(HEADER_SIZE); + } + + if let Some(encoding) = compression_encoding { + uncompression_buf.clear(); + + encoder + .encode(item, &mut EncodeBuf::new(uncompression_buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + + let uncompressed_len = uncompression_buf.len(); + + compress(encoding, uncompression_buf, buf, uncompressed_len) + .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + } else { + encoder + .encode(item, &mut EncodeBuf::new(buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; } + + // now that we know length, we can write the header + Ok(finish_encoding(compression_encoding, buf)) +} + +fn finish_encoding(compression_encoding: Option, buf: &mut BytesMut) -> Bytes { + let len = buf.len() - HEADER_SIZE; + assert!(len <= std::u32::MAX as usize); + { + let mut buf = &mut buf[..HEADER_SIZE]; + buf.put_u8(compression_encoding.is_some() as u8); + buf.put_u32(len as u32); + } + + buf.split_to(len + HEADER_SIZE).freeze() } #[derive(Debug)] @@ -131,6 +145,11 @@ enum Role { pub(crate) struct EncodeBody { #[pin] inner: S, + state: EncodeState, +} + +#[derive(Debug)] +struct EncodeState { error: Option, role: Role, is_end_stream: bool, @@ -143,18 +162,44 @@ where pub(crate) fn new_client(inner: S) -> Self { Self { inner, - error: None, - role: Role::Client, - is_end_stream: false, + state: EncodeState { + error: None, + role: Role::Client, + is_end_stream: false, + }, } } pub(crate) fn new_server(inner: S) -> Self { Self { inner, - error: None, - role: Role::Server, - is_end_stream: false, + state: EncodeState { + error: None, + role: Role::Server, + is_end_stream: false, + }, + } + } +} + +impl EncodeState { + fn trailers(&mut self) -> Result, Status> { + match self.role { + Role::Client => Ok(None), + Role::Server => { + if self.is_end_stream { + return Ok(None); + } + + let status = if let Some(status) = self.error.take() { + self.is_end_stream = true; + status + } else { + Status::new(Code::Ok, "") + }; + + Ok(Some(status.to_header_map()?)) + } } } } @@ -167,7 +212,7 @@ where type Error = Status; fn is_end_stream(&self) -> bool { - self.is_end_stream + self.state.is_end_stream } fn poll_data( @@ -177,10 +222,10 @@ where let mut self_proj = self.project(); match ready!(self_proj.inner.try_poll_next_unpin(cx)) { Some(Ok(d)) => Some(Ok(d)).into(), - Some(Err(status)) => match self_proj.role { + Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { - *self_proj.error = Some(status); + self_proj.state.error = Some(status); None.into() } }, @@ -192,24 +237,6 @@ where self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Status>> { - match self.role { - Role::Client => Poll::Ready(Ok(None)), - Role::Server => { - let self_proj = self.project(); - - if *self_proj.is_end_stream { - return Poll::Ready(Ok(None)); - } - - let status = if let Some(status) = self_proj.error.take() { - *self_proj.is_end_stream = true; - status - } else { - Status::new(Code::Ok, "") - }; - - Poll::Ready(Ok(Some(status.to_header_map()?))) - } - } + Poll::Ready(self.project().state.trailers()) } } diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 1d9d9d7c1..9db79301f 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -303,6 +303,13 @@ impl Status { Status::new(Code::Unauthenticated, message) } + #[cfg_attr(not(feature = "transport"), allow(dead_code))] + pub(crate) fn from_error_generic( + err: impl Into>, + ) -> Status { + Self::from_error(err.into()) + } + #[cfg_attr(not(feature = "transport"), allow(dead_code))] pub(crate) fn from_error(err: Box) -> Status { Status::try_from_error(err).unwrap_or_else(|err| { diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index 686aef197..e76ebf964 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -23,13 +23,7 @@ where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, { - async_stream::try_stream! { - futures_util::pin_mut!(incoming); - - while let Some(stream) = incoming.try_next().await? { - yield ServerIo::new_io(stream); - } - } + incoming.err_into().map_ok(ServerIo::new_io) } #[cfg(feature = "tls")]