From 11a551c15f7fff1e33d99f60d558fa18b211d4b9 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 14:36:59 +0200 Subject: [PATCH 01/15] refactor: Add a non-generic inner part of Streaming --- tonic/src/codec/decode.rs | 93 ++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 03614b1ac..6d5bda606 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -21,6 +21,10 @@ const BUFFER_SIZE: usize = 8 * 1024; /// to fetch the message stream and trailing metadata pub struct Streaming { decoder: Box + Send + 'static>, + inner: StreamingInner, +} + +struct StreamingInner { body: BoxBody, state: State, direction: Direction, @@ -96,16 +100,18 @@ impl Streaming { { Self { decoder: Box::new(decoder), - body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) - .map_err(|err| Status::map_error(err.into())) - .boxed_unsync(), - state: State::ReadHeader, - direction, - buf: BytesMut::with_capacity(BUFFER_SIZE), - trailers: None, - decompress_buf: BytesMut::new(), - encoding, + inner: StreamingInner { + body: body + .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_err(|err| Status::map_error(err.into())) + .boxed_unsync(), + state: State::ReadHeader, + direction, + buf: BytesMut::with_capacity(BUFFER_SIZE), + trailers: None, + decompress_buf: BytesMut::new(), + encoding, + }, } } } @@ -165,7 +171,7 @@ impl Streaming { pub async fn trailers(&mut self) -> Result, Status> { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status - if let Some(trailers) = self.trailers.take() { + if let Some(trailers) = self.inner.trailers.take() { return Ok(Some(trailers)); } @@ -174,13 +180,13 @@ impl Streaming { // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. - if let Some(trailers) = self.trailers.take() { + if let Some(trailers) = self.inner.trailers.take() { return Ok(Some(trailers)); } // Trailers were not caught during poll_next and thus lets poll for // them manually. - let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx)) + let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) .await .map_err(|e| Status::from_error(Box::new(e))); @@ -188,17 +194,17 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - if let State::ReadHeader = self.state { - if self.buf.remaining() < HEADER_SIZE { + if let State::ReadHeader = self.inner.state { + if self.inner.buf.remaining() < HEADER_SIZE { return Ok(None); } - let compression_encoding = match self.buf.get_u8() { + let compression_encoding = match self.inner.buf.get_u8() { 0 => None, 1 => { { - if self.encoding.is_some() { - self.encoding + if self.inner.encoding.is_some() { + self.inner.encoding } else { // https://grpc.github.io/grpc/core/md_doc_compression.html // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding @@ -210,7 +216,7 @@ impl Streaming { } f => { trace!("unexpected compression flag"); - let message = if let Direction::Response(status) = self.direction { + let message = if let Direction::Response(status) = self.inner.direction { format!( "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", f, status @@ -221,28 +227,32 @@ impl Streaming { return Err(Status::new(Code::Internal, message)); } }; - let len = self.buf.get_u32() as usize; - self.buf.reserve(len); + let len = self.inner.buf.get_u32() as usize; + self.inner.buf.reserve(len); - self.state = State::ReadBody { + self.inner.state = State::ReadBody { compression: compression_encoding, len, } } - if let State::ReadBody { len, compression } = self.state { + if let State::ReadBody { len, compression } = self.inner.state { // if we haven't read enough of the message then return and keep // reading - if self.buf.remaining() < len || self.buf.len() < len { + if self.inner.buf.remaining() < len || self.inner.buf.len() < len { return Ok(None); } let decoding_result = if let Some(encoding) = compression { - self.decompress_buf.clear(); - - if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) - { - let message = if let Direction::Response(status) = self.direction { + self.inner.decompress_buf.clear(); + + if let Err(err) = decompress( + encoding, + &mut self.inner.buf, + &mut self.inner.decompress_buf, + len, + ) { + let message = if let Direction::Response(status) = self.inner.direction { format!( "Error decompressing: {}, while receiving response with status: {}", err, status @@ -252,18 +262,19 @@ impl Streaming { }; return Err(Status::new(Code::Internal, message)); } - let decompressed_len = self.decompress_buf.len(); + let decompressed_len = self.inner.decompress_buf.len(); self.decoder.decode(&mut DecodeBuf::new( - &mut self.decompress_buf, + &mut self.inner.decompress_buf, decompressed_len, )) } else { - self.decoder.decode(&mut DecodeBuf::new(&mut self.buf, len)) + self.decoder + .decode(&mut DecodeBuf::new(&mut self.inner.buf, len)) }; return match decoding_result { Ok(Some(msg)) => { - self.state = State::ReadHeader; + self.inner.state = State::ReadHeader; Ok(Some(msg)) } Ok(None) => Ok(None), @@ -280,7 +291,7 @@ impl Stream for Streaming { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - if let State::Error = &self.state { + if let State::Error = &self.inner.state { return Poll::Ready(None); } @@ -291,10 +302,10 @@ impl Stream for Streaming { return Poll::Ready(Some(Ok(item))); } - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + let chunk = match ready!(Pin::new(&mut self.inner.body).poll_data(cx)) { Some(Ok(d)) => Some(d), Some(Err(e)) => { - let _ = std::mem::replace(&mut self.state, State::Error); + let _ = std::mem::replace(&mut self.inner.state, State::Error); let err: crate::Error = e.into(); debug!("decoder inner stream error: {:?}", err); let status = Status::from_error(err); @@ -304,10 +315,10 @@ impl Stream for Streaming { }; if let Some(data) = chunk { - self.buf.put(data); + self.inner.buf.put(data); } else { // FIXME: improve buf usage. - if self.buf.has_remaining() { + if self.inner.buf.has_remaining() { trace!("unexpected EOF decoding stream"); return Poll::Ready(Some(Err(Status::new( Code::Internal, @@ -319,8 +330,8 @@ impl Stream for Streaming { } } - if let Direction::Response(status) = self.direction { - match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { + if let Direction::Response(status) = self.inner.direction { + match ready!(Pin::new(&mut self.inner.body).poll_trailers(cx)) { Ok(trailer) => { if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { if let Some(e) = e { @@ -329,7 +340,7 @@ impl Stream for Streaming { return Poll::Ready(None); } } else { - self.trailers = trailer.map(MetadataMap::from_headers); + self.inner.trailers = trailer.map(MetadataMap::from_headers); } } Err(e) => { From 8debe0a7b8badeded468855204e5c0f6553a708d Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 14:44:49 +0200 Subject: [PATCH 02/15] refactor: Factor out the non-generic part of decode_chunk --- tonic/src/codec/decode.rs | 170 +++++++++++++++++++------------------- 1 file changed, 85 insertions(+), 85 deletions(-) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 6d5bda606..56d353341 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -116,6 +116,85 @@ impl Streaming { } } +impl StreamingInner { + fn decode_chunk(&mut self) -> Result>, Status> { + if let State::ReadHeader = self.state { + if self.buf.remaining() < HEADER_SIZE { + return Ok(None); + } + + let compression_encoding = match self.buf.get_u8() { + 0 => None, + 1 => { + { + if self.encoding.is_some() { + self.encoding + } else { + // https://grpc.github.io/grpc/core/md_doc_compression.html + // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding + // entry different from identity in its metadata MUST fail with INTERNAL status, + // its associated description indicating the invalid Compressed-Flag condition. + return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified")); + } + } + } + f => { + trace!("unexpected compression flag"); + let message = if let Direction::Response(status) = self.direction { + format!( + "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", + f, status + ) + } else { + format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f) + }; + return Err(Status::new(Code::Internal, message)); + } + }; + let len = self.buf.get_u32() as usize; + self.buf.reserve(len); + + self.state = State::ReadBody { + compression: compression_encoding, + len, + } + } + + if let State::ReadBody { len, compression } = self.state { + // if we haven't read enough of the message then return and keep + // reading + if self.buf.remaining() < len || self.buf.len() < len { + return Ok(None); + } + + let decode_buf = if let Some(encoding) = compression { + self.decompress_buf.clear(); + + if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) + { + let message = if let Direction::Response(status) = self.direction { + format!( + "Error decompressing: {}, while receiving response with status: {}", + err, status + ) + } else { + format!("Error decompressing: {}, while sending request", err) + }; + return Err(Status::new(Code::Internal, message)); + } + let decompressed_len = self.decompress_buf.len(); + DecodeBuf::new(&mut self.decompress_buf, decompressed_len) + } else { + DecodeBuf::new(&mut self.buf, len) + }; + + return Ok(Some(decode_buf)); + } + + Ok(None) + } +} + impl Streaming { /// Fetch the next message from this stream. /// @@ -194,95 +273,16 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - if let State::ReadHeader = self.inner.state { - if self.inner.buf.remaining() < HEADER_SIZE { - return Ok(None); - } - - let compression_encoding = match self.inner.buf.get_u8() { - 0 => None, - 1 => { - { - if self.inner.encoding.is_some() { - self.inner.encoding - } else { - // https://grpc.github.io/grpc/core/md_doc_compression.html - // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding - // entry different from identity in its metadata MUST fail with INTERNAL status, - // its associated description indicating the invalid Compressed-Flag condition. - return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified")); - } - } - } - f => { - trace!("unexpected compression flag"); - let message = if let Direction::Response(status) = self.inner.direction { - format!( - "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}", - f, status - ) - } else { - format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f) - }; - return Err(Status::new(Code::Internal, message)); - } - }; - let len = self.inner.buf.get_u32() as usize; - self.inner.buf.reserve(len); - - self.inner.state = State::ReadBody { - compression: compression_encoding, - len, - } - } - - if let State::ReadBody { len, compression } = self.inner.state { - // if we haven't read enough of the message then return and keep - // reading - if self.inner.buf.remaining() < len || self.inner.buf.len() < len { - return Ok(None); - } - - let decoding_result = if let Some(encoding) = compression { - self.inner.decompress_buf.clear(); - - if let Err(err) = decompress( - encoding, - &mut self.inner.buf, - &mut self.inner.decompress_buf, - len, - ) { - let message = if let Direction::Response(status) = self.inner.direction { - format!( - "Error decompressing: {}, while receiving response with status: {}", - err, status - ) - } else { - format!("Error decompressing: {}, while sending request", err) - }; - return Err(Status::new(Code::Internal, message)); - } - let decompressed_len = self.inner.decompress_buf.len(); - self.decoder.decode(&mut DecodeBuf::new( - &mut self.inner.decompress_buf, - decompressed_len, - )) - } else { - self.decoder - .decode(&mut DecodeBuf::new(&mut self.inner.buf, len)) - }; - - return match decoding_result { - Ok(Some(msg)) => { + match self.inner.decode_chunk()? { + Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { + Some(msg) => { self.inner.state = State::ReadHeader; Ok(Some(msg)) } - Ok(None) => Ok(None), - Err(e) => Err(e), - }; + None => Ok(None), + }, + None => Ok(None), } - - Ok(None) } } From d684a143310ceb9b72793ebe93d1c721f1947b42 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 14:59:12 +0200 Subject: [PATCH 03/15] refactor: Factor out a non-generic part of poll_next --- tonic/src/codec/decode.rs | 53 +++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 56d353341..ba8dfb313 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -193,6 +193,31 @@ impl StreamingInner { Ok(None) } + + fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Direction::Response(status) = self.direction { + match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { + Ok(trailer) => { + if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { + if let Some(e) = e { + return Poll::Ready(Err(e)); + } else { + return Poll::Ready(Ok(())); + } + } else { + self.trailers = trailer.map(MetadataMap::from_headers); + } + } + Err(e) => { + let err: crate::Error = e.into(); + debug!("decoder inner trailers error: {:?}", err); + let status = Status::from_error(err); + return Poll::Ready(Err(status)); + } + } + } + Poll::Ready(Ok(())) + } } impl Streaming { @@ -329,30 +354,10 @@ impl Stream for Streaming { } } } - - if let Direction::Response(status) = self.inner.direction { - match ready!(Pin::new(&mut self.inner.body).poll_trailers(cx)) { - Ok(trailer) => { - if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - if let Some(e) = e { - return Some(Err(e)).into(); - } else { - return Poll::Ready(None); - } - } else { - self.inner.trailers = trailer.map(MetadataMap::from_headers); - } - } - Err(e) => { - let err: crate::Error = e.into(); - debug!("decoder inner trailers error: {:?}", err); - let status = Status::from_error(err); - return Some(Err(status)).into(); - } - } - } - - Poll::Ready(None) + Poll::Ready(match ready!(self.inner.poll_response(cx)) { + Ok(()) => None, + Err(err) => Some(Err(err)), + }) } } From 539b5f418c65a47cf6f4a5cf262e423cffa52e67 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 15:04:12 +0200 Subject: [PATCH 04/15] refactor: Factor out a non-generic part of poll_next --- tonic/src/codec/decode.rs | 60 +++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index ba8dfb313..bb422652f 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -194,6 +194,37 @@ impl StreamingInner { Ok(None) } + // Returns Some(()) if data was found or None if the loop in `poll_next` should break + fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { + let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + Some(Ok(d)) => Some(d), + Some(Err(e)) => { + let _ = std::mem::replace(&mut self.state, State::Error); + let err: crate::Error = e.into(); + debug!("decoder inner stream error: {:?}", err); + let status = Status::from_error(err); + return Poll::Ready(Err(status)); + } + None => None, + }; + + Poll::Ready(if let Some(data) = chunk { + self.buf.put(data); + Ok(Some(())) + } else { + // FIXME: improve buf usage. + if self.buf.has_remaining() { + trace!("unexpected EOF decoding stream"); + Err(Status::new( + Code::Internal, + "Unexpected EOF decoding stream.".to_string(), + )) + } else { + Ok(None) + } + }) + } + fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { if let Direction::Response(status) = self.direction { match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { @@ -327,33 +358,12 @@ impl Stream for Streaming { return Poll::Ready(Some(Ok(item))); } - let chunk = match ready!(Pin::new(&mut self.inner.body).poll_data(cx)) { - Some(Ok(d)) => Some(d), - Some(Err(e)) => { - let _ = std::mem::replace(&mut self.inner.state, State::Error); - let err: crate::Error = e.into(); - debug!("decoder inner stream error: {:?}", err); - let status = Status::from_error(err); - return Poll::Ready(Some(Err(status))); - } - None => None, - }; - - if let Some(data) = chunk { - self.inner.buf.put(data); - } else { - // FIXME: improve buf usage. - if self.inner.buf.has_remaining() { - trace!("unexpected EOF decoding stream"); - return Poll::Ready(Some(Err(Status::new( - Code::Internal, - "Unexpected EOF decoding stream.".to_string(), - )))); - } else { - break; - } + match ready!(self.inner.poll_data(cx))? { + Some(()) => (), + None => break, } } + Poll::Ready(match ready!(self.inner.poll_response(cx)) { Ok(()) => None, Err(err) => Some(Err(err)), From 6434b6aeb5bfbb9af8b26891e8c8f6e04a536bd2 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 15:18:50 +0200 Subject: [PATCH 05/15] refactor: Avoid using async_stream when not necessary --- tonic/src/codec/encode.rs | 110 ++++++++++++------------- tonic/src/transport/server/incoming.rs | 8 +- 2 files changed, 52 insertions(+), 66 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index fbeb1e870..64e2b3c74 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -58,66 +58,58 @@ where T: Encoder, U: Stream>, { - async_stream::stream! { - let mut buf = BytesMut::with_capacity(BUFFER_SIZE); - - let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable { - None - } else { - compression_encoding - }; - - let mut uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) - } else { - BytesMut::new() - }; - - futures_util::pin_mut!(source); - - loop { - match source.next().await { - Some(Ok(item)) => { - buf.reserve(HEADER_SIZE); - unsafe { - buf.advance_mut(HEADER_SIZE); - } - - if let Some(encoding) = compression_encoding { - uncompression_buf.clear(); - - encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - - let uncompressed_len = uncompression_buf.len(); - - compress( - encoding, - &mut uncompression_buf, - &mut buf, - uncompressed_len, - ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; - } else { - encoder.encode(item, &mut EncodeBuf::new(&mut buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - } - - // now that we know length, we can write the header - let len = buf.len() - HEADER_SIZE; - assert!(len <= std::u32::MAX as usize); - { - let mut buf = &mut buf[..HEADER_SIZE]; - buf.put_u8(compression_encoding.is_some() as u8); - buf.put_u32(len as u32); - } - - yield Ok(buf.split_to(len + HEADER_SIZE).freeze()); - }, - Some(Err(status)) => yield Err(status), - None => break, + let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + + let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable + { + None + } else { + compression_encoding + }; + + let mut uncompression_buf = if compression_encoding.is_some() { + BytesMut::with_capacity(BUFFER_SIZE) + } else { + BytesMut::new() + }; + + source.and_then(move |item| { + let result = (|| { + buf.reserve(HEADER_SIZE); + unsafe { + buf.advance_mut(HEADER_SIZE); } - } - } + + if let Some(encoding) = compression_encoding { + uncompression_buf.clear(); + + encoder + .encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + + let uncompressed_len = uncompression_buf.len(); + + compress(encoding, &mut uncompression_buf, &mut buf, uncompressed_len) + .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + } else { + encoder + .encode(item, &mut EncodeBuf::new(&mut buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + } + + // now that we know length, we can write the header + let len = buf.len() - HEADER_SIZE; + assert!(len <= std::u32::MAX as usize); + { + let mut buf = &mut buf[..HEADER_SIZE]; + buf.put_u8(compression_encoding.is_some() as u8); + buf.put_u32(len as u32); + } + + Ok(buf.split_to(len + HEADER_SIZE).freeze()) + })(); + async { result } + }) } #[derive(Debug)] diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index 686aef197..e76ebf964 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -23,13 +23,7 @@ where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, { - async_stream::try_stream! { - futures_util::pin_mut!(incoming); - - while let Some(stream) = incoming.try_next().await? { - yield ServerIo::new_io(stream); - } - } + incoming.err_into().map_ok(ServerIo::new_io) } #[cfg(feature = "tls")] From cb40df814b257a52d97b2291075df1704ad1b871 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 15:31:21 +0200 Subject: [PATCH 06/15] refactor: Factor out a non-generic part of encode --- tonic/src/codec/encode.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 64e2b3c74..21cba8618 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -98,20 +98,24 @@ where } // now that we know length, we can write the header - let len = buf.len() - HEADER_SIZE; - assert!(len <= std::u32::MAX as usize); - { - let mut buf = &mut buf[..HEADER_SIZE]; - buf.put_u8(compression_encoding.is_some() as u8); - buf.put_u32(len as u32); - } - - Ok(buf.split_to(len + HEADER_SIZE).freeze()) + Ok(finish_encoding(compression_encoding, &mut buf)) })(); async { result } }) } +fn finish_encoding(compression_encoding: Option, buf: &mut BytesMut) -> Bytes { + let len = buf.len() - HEADER_SIZE; + assert!(len <= std::u32::MAX as usize); + { + let mut buf = &mut buf[..HEADER_SIZE]; + buf.put_u8(compression_encoding.is_some() as u8); + buf.put_u32(len as u32); + } + + buf.split_to(len + HEADER_SIZE).freeze() +} + #[derive(Debug)] enum Role { Client, From d7eeb1fe7e8fd4514a3b6ef2c6b150be39a00250 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 16:01:03 +0200 Subject: [PATCH 07/15] refactor: Factor out a less generic part of streaming --- tonic/src/client/grpc.rs | 131 +++++++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 53 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index a1a3675ef..50dcf859b 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -30,6 +30,10 @@ use std::fmt; /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, + config: GrpcConfig, +} + +struct GrpcConfig { origin: Uri, /// Which compression encodings does the client accept? accept_compression_encodings: EnabledCompressionEncodings, @@ -42,9 +46,11 @@ impl Grpc { pub fn new(inner: T) -> Self { Self { inner, - origin: Uri::default(), - send_compression_encodings: None, - accept_compression_encodings: EnabledCompressionEncodings::default(), + config: GrpcConfig { + origin: Uri::default(), + send_compression_encodings: None, + accept_compression_encodings: EnabledCompressionEncodings::default(), + }, } } @@ -55,9 +61,11 @@ impl Grpc { pub fn with_origin(inner: T, origin: Uri) -> Self { Self { inner, - origin, - send_compression_encodings: None, - accept_compression_encodings: EnabledCompressionEncodings::default(), + config: GrpcConfig { + origin, + send_compression_encodings: None, + accept_compression_encodings: EnabledCompressionEncodings::default(), + }, } } @@ -88,7 +96,7 @@ impl Grpc { /// # }; /// ``` pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.send_compression_encodings = Some(encoding); + self.config.send_compression_encodings = Some(encoding); self } @@ -119,7 +127,7 @@ impl Grpc { /// # }; /// ``` pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.accept_compression_encodings.enable(encoding); + self.config.accept_compression_encodings.enable(encoding); self } @@ -226,6 +234,56 @@ impl Grpc { M1: Send + Sync + 'static, M2: Send + Sync + 'static, { + let request = request + .map(|s| encode_client(codec.encoder(), s, self.config.send_compression_encodings)) + .map(BoxBody::new); + + let request = self.config.prepare_request(request, path); + + let response = self + .inner + .call(request) + .await + .map_err(|err| Status::from_error(err.into()))?; + + let encoding = CompressionEncoding::from_encoding_header( + response.headers(), + self.config.accept_compression_encodings, + )?; + + let status_code = response.status(); + let trailers_only_status = Status::from_header_map(response.headers()); + + // We do not need to check for trailers if the `grpc-status` header is present + // with a valid code. + let expect_additional_trailers = if let Some(status) = trailers_only_status { + if status.code() != Code::Ok { + return Err(status); + } + + false + } else { + true + }; + + let response = response.map(|body| { + if expect_additional_trailers { + Streaming::new_response(codec.decoder(), body, status_code, encoding) + } else { + Streaming::new_empty(codec.decoder(), body) + } + }); + + Ok(Response::from_http(response)) + } +} + +impl GrpcConfig { + fn prepare_request( + &self, + request: Request>, + path: PathAndQuery, + ) -> http::Request> { let scheme = self.origin.scheme().cloned(); let authority = self.origin.authority().cloned(); @@ -236,10 +294,6 @@ impl Grpc { let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); - let request = request - .map(|s| encode_client(codec.encoder(), s, self.send_compression_encodings)) - .map(BoxBody::new); - let mut request = request.into_http( uri, http::Method::POST, @@ -274,41 +328,7 @@ impl Grpc { ); } - let response = self - .inner - .call(request) - .await - .map_err(|err| Status::from_error(err.into()))?; - - let encoding = CompressionEncoding::from_encoding_header( - response.headers(), - self.accept_compression_encodings, - )?; - - let status_code = response.status(); - let trailers_only_status = Status::from_header_map(response.headers()); - - // We do not need to check for trailers if the `grpc-status` header is present - // with a valid code. - let expect_additional_trailers = if let Some(status) = trailers_only_status { - if status.code() != Code::Ok { - return Err(status); - } - - false - } else { - true - }; - - let response = response.map(|body| { - if expect_additional_trailers { - Streaming::new_response(codec.decoder(), body, status_code, encoding) - } else { - Streaming::new_empty(codec.decoder(), body) - } - }); - - Ok(Response::from_http(response)) + request } } @@ -316,9 +336,11 @@ impl Clone for Grpc { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - origin: self.origin.clone(), - send_compression_encodings: self.send_compression_encodings, - accept_compression_encodings: self.accept_compression_encodings, + config: GrpcConfig { + origin: self.config.origin.clone(), + send_compression_encodings: self.config.send_compression_encodings, + accept_compression_encodings: self.config.accept_compression_encodings, + }, } } } @@ -329,13 +351,16 @@ impl fmt::Debug for Grpc { f.field("inner", &self.inner); - f.field("origin", &self.origin); + f.field("origin", &self.config.origin); - f.field("compression_encoding", &self.send_compression_encodings); + f.field( + "compression_encoding", + &self.config.send_compression_encodings, + ); f.field( "accept_compression_encodings", - &self.accept_compression_encodings, + &self.config.accept_compression_encodings, ); f.finish() From adfa597ec5307f0e6308241d7a54656579ed4208 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 16:05:48 +0200 Subject: [PATCH 08/15] refactor: Implement Grpc::new in terms of with_origin --- tonic/src/client/grpc.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 50dcf859b..6a25ee109 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -44,14 +44,7 @@ struct GrpcConfig { impl Grpc { /// Creates a new gRPC client with the provided [`GrpcService`]. pub fn new(inner: T) -> Self { - Self { - inner, - config: GrpcConfig { - origin: Uri::default(), - send_compression_encodings: None, - accept_compression_encodings: EnabledCompressionEncodings::default(), - }, - } + Self::with_origin(inner, Uri::default()) } /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`. From 97fd87c8a8e1b6123cb8373d0c38c5901cc1e860 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Thu, 30 Jun 2022 16:14:11 +0200 Subject: [PATCH 09/15] refactor: Factor out poll_trailers to a non-generic function --- tonic/src/codec/encode.rs | 69 +++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 21cba8618..06a80a5d3 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -127,6 +127,11 @@ enum Role { pub(crate) struct EncodeBody { #[pin] inner: S, + state: EncodeState, +} + +#[derive(Debug)] +struct EncodeState { error: Option, role: Role, is_end_stream: bool, @@ -139,18 +144,44 @@ where pub(crate) fn new_client(inner: S) -> Self { Self { inner, - error: None, - role: Role::Client, - is_end_stream: false, + state: EncodeState { + error: None, + role: Role::Client, + is_end_stream: false, + }, } } pub(crate) fn new_server(inner: S) -> Self { Self { inner, - error: None, - role: Role::Server, - is_end_stream: false, + state: EncodeState { + error: None, + role: Role::Server, + is_end_stream: false, + }, + } + } +} + +impl EncodeState { + fn trailers(&mut self) -> Result, Status> { + match self.role { + Role::Client => Ok(None), + Role::Server => { + if self.is_end_stream { + return Ok(None); + } + + let status = if let Some(status) = self.error.take() { + self.is_end_stream = true; + status + } else { + Status::new(Code::Ok, "") + }; + + Ok(Some(status.to_header_map()?)) + } } } } @@ -163,7 +194,7 @@ where type Error = Status; fn is_end_stream(&self) -> bool { - self.is_end_stream + self.state.is_end_stream } fn poll_data( @@ -173,10 +204,10 @@ where let mut self_proj = self.project(); match ready!(self_proj.inner.try_poll_next_unpin(cx)) { Some(Ok(d)) => Some(Ok(d)).into(), - Some(Err(status)) => match self_proj.role { + Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { - *self_proj.error = Some(status); + self_proj.state.error = Some(status); None.into() } }, @@ -188,24 +219,6 @@ where self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Status>> { - match self.role { - Role::Client => Poll::Ready(Ok(None)), - Role::Server => { - let self_proj = self.project(); - - if *self_proj.is_end_stream { - return Poll::Ready(Ok(None)); - } - - let status = if let Some(status) = self_proj.error.take() { - *self_proj.is_end_stream = true; - status - } else { - Status::new(Code::Ok, "") - }; - - Poll::Ready(Ok(Some(status.to_header_map()?))) - } - } + Poll::Ready(self.project().state.trailers()) } } From 9a1a22a325134f486f3ec82ebb26e9c0bcb40e26 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Wed, 10 Aug 2022 11:53:56 +0200 Subject: [PATCH 10/15] refactor: Remove an unnecessary closure --- tonic/src/codec/encode.rs | 45 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 06a80a5d3..eee878dc8 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -73,34 +73,33 @@ where BytesMut::new() }; - source.and_then(move |item| { - let result = (|| { - buf.reserve(HEADER_SIZE); - unsafe { - buf.advance_mut(HEADER_SIZE); - } + source.map(move |result| { + let item = result?; - if let Some(encoding) = compression_encoding { - uncompression_buf.clear(); + buf.reserve(HEADER_SIZE); + unsafe { + buf.advance_mut(HEADER_SIZE); + } - encoder - .encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + if let Some(encoding) = compression_encoding { + uncompression_buf.clear(); - let uncompressed_len = uncompression_buf.len(); + encoder + .encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - compress(encoding, &mut uncompression_buf, &mut buf, uncompressed_len) - .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; - } else { - encoder - .encode(item, &mut EncodeBuf::new(&mut buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - } + let uncompressed_len = uncompression_buf.len(); + + compress(encoding, &mut uncompression_buf, &mut buf, uncompressed_len) + .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + } else { + encoder + .encode(item, &mut EncodeBuf::new(&mut buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + } - // now that we know length, we can write the header - Ok(finish_encoding(compression_encoding, &mut buf)) - })(); - async { result } + // now that we know length, we can write the header + Ok(finish_encoding(compression_encoding, &mut buf)) }) } From d3a1f52a9550261a1d1434b6739615cd33702d01 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Wed, 10 Aug 2022 12:27:43 +0200 Subject: [PATCH 11/15] refactor: Share the output decoding across functions (-2%) --- tonic/src/client/grpc.rs | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 6a25ee109..51f5af7a9 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -2,7 +2,7 @@ use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings use crate::{ body::BoxBody, client::GrpcService, - codec::{encode_client, Codec, Streaming}, + codec::{encode_client, Codec, Decoder, Streaming}, request::SanitizeHeaders, Code, Request, Response, Status, }; @@ -239,6 +239,22 @@ impl Grpc { .await .map_err(|err| Status::from_error(err.into()))?; + let decoder = codec.decoder(); + + self.create_response(decoder, response) + } + + // Keeping this code in a separate function from Self::streaming lets functions that return the + // same output share the generated binary code + fn create_response( + &self, + decoder: impl Decoder + Send + 'static, + response: http::Response, + ) -> Result>, Status> + where + B: Body + Send + 'static, + ::Error: Into, + { let encoding = CompressionEncoding::from_encoding_header( response.headers(), self.config.accept_compression_encodings, @@ -261,9 +277,9 @@ impl Grpc { let response = response.map(|body| { if expect_additional_trailers { - Streaming::new_response(codec.decoder(), body, status_code, encoding) + Streaming::new_response(decoder, body, status_code, encoding) } else { - Streaming::new_empty(codec.decoder(), body) + Streaming::new_empty(decoder, body) } }); From 4ccd3e51979da2f8653d4ee2746c06dd242b8ae1 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Wed, 10 Aug 2022 12:34:36 +0200 Subject: [PATCH 12/15] refactor: Remove an unnecessary closure --- tonic/src/client/grpc.rs | 6 +----- tonic/src/status.rs | 7 ++++++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 51f5af7a9..0f4acec92 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -233,11 +233,7 @@ impl Grpc { let request = self.config.prepare_request(request, path); - let response = self - .inner - .call(request) - .await - .map_err(|err| Status::from_error(err.into()))?; + let response = self.inner.call(request).await.map_err(Status::from_error)?; let decoder = codec.decoder(); diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 1d9d9d7c1..320bb89df 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -304,7 +304,12 @@ impl Status { } #[cfg_attr(not(feature = "transport"), allow(dead_code))] - pub(crate) fn from_error(err: Box) -> Status { + pub(crate) fn from_error(err: impl Into>) -> Status { + Self::from_error_impl(err.into()) + } + + #[cfg_attr(not(feature = "transport"), allow(dead_code))] + fn from_error_impl(err: Box) -> Status { Status::try_from_error(err).unwrap_or_else(|err| { let mut status = Status::new(Code::Unknown, err.to_string()); status.source = Some(err); From 14423329d89f95336b1e1c7aaa9fdd2101d28640 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Wed, 10 Aug 2022 12:37:06 +0200 Subject: [PATCH 13/15] refactor: Remove an unnecessary type parameter --- tonic/src/client/grpc.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 0f4acec92..6c9959e8b 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -242,14 +242,15 @@ impl Grpc { // Keeping this code in a separate function from Self::streaming lets functions that return the // same output share the generated binary code - fn create_response( + fn create_response( &self, decoder: impl Decoder + Send + 'static, - response: http::Response, + response: http::Response, ) -> Result>, Status> where - B: Body + Send + 'static, - ::Error: Into, + T: GrpcService, + T::ResponseBody: Body + Send + 'static, + ::Error: Into, { let encoding = CompressionEncoding::from_encoding_header( response.headers(), From 628f9b4dc27ea5f230ca3eee6f9d18a4e8bae6bd Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Wed, 10 Aug 2022 12:43:49 +0200 Subject: [PATCH 14/15] refactor: Extract a function only generic on Encoder in encode --- tonic/src/codec/encode.rs | 59 ++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index eee878dc8..60a8548ce 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -76,31 +76,50 @@ where source.map(move |result| { let item = result?; - buf.reserve(HEADER_SIZE); - unsafe { - buf.advance_mut(HEADER_SIZE); - } + encode_item( + &mut encoder, + &mut buf, + &mut uncompression_buf, + compression_encoding, + item, + ) + }) +} - if let Some(encoding) = compression_encoding { - uncompression_buf.clear(); +fn encode_item( + encoder: &mut T, + buf: &mut BytesMut, + uncompression_buf: &mut BytesMut, + compression_encoding: Option, + item: T::Item, +) -> Result +where + T: Encoder, +{ + buf.reserve(HEADER_SIZE); + unsafe { + buf.advance_mut(HEADER_SIZE); + } - encoder - .encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + if let Some(encoding) = compression_encoding { + uncompression_buf.clear(); - let uncompressed_len = uncompression_buf.len(); + encoder + .encode(item, &mut EncodeBuf::new(uncompression_buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - compress(encoding, &mut uncompression_buf, &mut buf, uncompressed_len) - .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; - } else { - encoder - .encode(item, &mut EncodeBuf::new(&mut buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - } + let uncompressed_len = uncompression_buf.len(); - // now that we know length, we can write the header - Ok(finish_encoding(compression_encoding, &mut buf)) - }) + compress(encoding, uncompression_buf, buf, uncompressed_len) + .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + } else { + encoder + .encode(item, &mut EncodeBuf::new(buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + } + + // now that we know length, we can write the header + Ok(finish_encoding(compression_encoding, buf)) } fn finish_encoding(compression_encoding: Option, buf: &mut BytesMut) -> Bytes { From ff20380db1ca08f0449667cb6558da18f14279ca Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Fri, 19 Aug 2022 18:01:45 +0200 Subject: [PATCH 15/15] test: Correct the test for from_error Seems to be some weirdness around type inference and `Into` causing the downcasts to fail --- tonic/src/client/grpc.rs | 6 +++++- tonic/src/status.rs | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 6c9959e8b..85134ceae 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -233,7 +233,11 @@ impl Grpc { let request = self.config.prepare_request(request, path); - let response = self.inner.call(request).await.map_err(Status::from_error)?; + let response = self + .inner + .call(request) + .await + .map_err(Status::from_error_generic)?; let decoder = codec.decoder(); diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 320bb89df..9db79301f 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -304,12 +304,14 @@ impl Status { } #[cfg_attr(not(feature = "transport"), allow(dead_code))] - pub(crate) fn from_error(err: impl Into>) -> Status { - Self::from_error_impl(err.into()) + pub(crate) fn from_error_generic( + err: impl Into>, + ) -> Status { + Self::from_error(err.into()) } #[cfg_attr(not(feature = "transport"), allow(dead_code))] - fn from_error_impl(err: Box) -> Status { + pub(crate) fn from_error(err: Box) -> Status { Status::try_from_error(err).unwrap_or_else(|err| { let mut status = Status::new(Code::Unknown, err.to_string()); status.source = Some(err);