diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index cfe62d9ef..daa05fa4e 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -57,8 +57,8 @@ pub fn generate( impl #service_ident where T: tonic::client::GrpcService, - T::ResponseBody: Body + Send + 'static, T::Error: Into, + T::ResponseBody: Default + Body + Send + 'static, ::Error: Into + Send, { pub fn new(inner: T) -> Self { diff --git a/tonic/src/codegen.rs b/tonic/src/codegen.rs index 42fe53f20..78a2c88a8 100644 --- a/tonic/src/codegen.rs +++ b/tonic/src/codegen.rs @@ -13,6 +13,7 @@ pub type StdError = Box; #[cfg(feature = "compression")] pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings}; pub use crate::service::interceptor::InterceptedService; +pub use bytes::Bytes; pub use http_body::Body; pub type BoxFuture = self::Pin> + Send + 'static>>; diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index 85477e432..441eb9163 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -3,6 +3,7 @@ //! See [`Interceptor`] for more details. use crate::{request::SanitizeHeaders, Status}; +use bytes::Bytes; use pin_project::pin_project; use std::{ fmt, @@ -140,9 +141,11 @@ where impl Service> for InterceptedService where + ResBody: Default + http_body::Body + Send + 'static, F: Interceptor, S: Service, Response = http::Response>, S::Error: Into, + ResBody::Error: Into, { type Response = http::Response; type Error = crate::Error; @@ -215,6 +218,8 @@ impl Future for ResponseFuture where F: Future, E>>, E: Into, + B: Default + http_body::Body + Send + 'static, + B::Error: Into, { type Output = Result, crate::Error>; @@ -222,8 +227,9 @@ where match self.project().kind.project() { KindProj::Future(future) => future.poll(cx).map_err(Into::into), KindProj::Error(status) => { - let error = status.take().unwrap().into(); - Poll::Ready(Err(error)) + let response = status.take().unwrap().to_http().map(|_| B::default()); + + Poll::Ready(Ok(response)) } } } @@ -233,11 +239,38 @@ where mod tests { #[allow(unused_imports)] use super::*; + use http::header::HeaderMap; + use std::{ + pin::Pin, + task::{Context, Poll}, + }; use tower::ServiceExt; + #[derive(Debug, Default)] + struct TestBody; + + impl http_body::Body for TestBody { + type Data = Bytes; + type Error = Status; + + fn poll_data( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(None) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(None)) + } + } + #[tokio::test] - async fn doesnt_remove_headers() { - let svc = tower::service_fn(|request: http::Request| async move { + async fn doesnt_remove_headers_from_requests() { + let svc = tower::service_fn(|request: http::Request| async move { assert_eq!( request .headers() @@ -246,7 +279,7 @@ mod tests { "test-tonic" ); - Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + Ok::<_, Status>(http::Response::new(TestBody)) }); let svc = InterceptedService::new(svc, |request: crate::Request<()>| { @@ -257,14 +290,36 @@ mod tests { .expect("missing in interceptor"), "test-tonic" ); + Ok(request) }); let request = http::Request::builder() .header("user-agent", "test-tonic") - .body(hyper::Body::empty()) + .body(TestBody) .unwrap(); svc.oneshot(request).await.unwrap(); } + + #[tokio::test] + async fn handles_intercepted_status_as_response() { + let message = "Blocked by the interceptor"; + let expected = Status::permission_denied(message).to_http(); + + let svc = tower::service_fn(|_: http::Request| async { + Ok::<_, Status>(http::Response::new(TestBody)) + }); + + let svc = InterceptedService::new(svc, |_: crate::Request<()>| { + Err(Status::permission_denied(message)) + }); + + let request = http::Request::builder().body(TestBody).unwrap(); + let response = svc.oneshot(request).await.unwrap(); + + assert_eq!(expected.status(), response.status()); + assert_eq!(expected.version(), response.version()); + assert_eq!(expected.headers(), response.headers()); + } }