diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index 85477e432..25b1326b2 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -3,6 +3,7 @@ //! See [`Interceptor`] for more details. use crate::{request::SanitizeHeaders, Status}; +use bytes::Bytes; use pin_project::pin_project; use std::{ fmt, @@ -140,6 +141,7 @@ where impl Service> for InterceptedService where + ResBody: Default + http_body::Body + Send + 'static, F: Interceptor, S: Service, Response = http::Response>, S::Error: Into, @@ -215,6 +217,7 @@ impl Future for ResponseFuture where F: Future, E>>, E: Into, + B: Default + http_body::Body + Send + 'static, { type Output = Result, crate::Error>; @@ -222,8 +225,9 @@ where match self.project().kind.project() { KindProj::Future(future) => future.poll(cx).map_err(Into::into), KindProj::Error(status) => { - let error = status.take().unwrap().into(); - Poll::Ready(Err(error)) + let response = status.take().unwrap().to_http().map(|_| B::default()); + + Poll::Ready(Ok(response)) } } } @@ -233,11 +237,38 @@ where mod tests { #[allow(unused_imports)] use super::*; + use http::header::HeaderMap; + use std::{ + pin::Pin, + task::{Context, Poll}, + }; use tower::ServiceExt; + #[derive(Debug, Default)] + struct TestBody; + + impl http_body::Body for TestBody { + type Data = Bytes; + type Error = Status; + + fn poll_data( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(None) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(None)) + } + } + #[tokio::test] - async fn doesnt_remove_headers() { - let svc = tower::service_fn(|request: http::Request| async move { + async fn doesnt_remove_headers_from_requests() { + let svc = tower::service_fn(|request: http::Request| async move { assert_eq!( request .headers() @@ -246,7 +277,7 @@ mod tests { "test-tonic" ); - Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + Ok::<_, Status>(http::Response::new(TestBody)) }); let svc = InterceptedService::new(svc, |request: crate::Request<()>| { @@ -257,14 +288,36 @@ mod tests { .expect("missing in interceptor"), "test-tonic" ); + Ok(request) }); let request = http::Request::builder() .header("user-agent", "test-tonic") - .body(hyper::Body::empty()) + .body(TestBody) .unwrap(); svc.oneshot(request).await.unwrap(); } + + #[tokio::test] + async fn handles_intercepted_status_as_response() { + let message = "Blocked by the interceptor"; + let expected = Status::permission_denied(message).to_http(); + + let svc = tower::service_fn(|_: http::Request| async { + Ok::<_, Status>(http::Response::new(TestBody)) + }); + + let svc = InterceptedService::new(svc, |_: crate::Request<()>| { + Err(Status::permission_denied(message)) + }); + + let request = http::Request::builder().body(TestBody).unwrap(); + let response = svc.oneshot(request).await.unwrap(); + + assert_eq!(expected.status(), response.status()); + assert_eq!(expected.version(), response.version()); + assert_eq!(expected.headers(), response.headers()); + } }