Skip to content

Commit

Permalink
fix(tonic): Handle interceptor errors as responses (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
NAlexPear committed Nov 23, 2021
1 parent c62f382 commit 2a61ea0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 8 deletions.
4 changes: 2 additions & 2 deletions tonic-build/src/client.rs
Expand Up @@ -15,7 +15,7 @@ pub fn generate<T: Service>(
attributes: &Attributes,
) -> TokenStream {
let service_ident = quote::format_ident!("{}Client", service.name());
let client_mod = quote::format_ident!("{}_client", naive_snake_case(&service.name()));
let client_mod = quote::format_ident!("{}_client", naive_snake_case(service.name()));
let methods = generate_methods(service, emit_package, proto_path, compile_well_known_types);

let connect = generate_connect(&service_ident);
Expand Down Expand Up @@ -57,8 +57,8 @@ pub fn generate<T: Service>(
impl<T> #service_ident<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::ResponseBody: Body + Send + 'static,
T::Error: Into<StdError>,
T::ResponseBody: Default + Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(inner: T) -> Self {
Expand Down
1 change: 1 addition & 0 deletions tonic/src/codegen.rs
Expand Up @@ -13,6 +13,7 @@ pub type StdError = Box<dyn std::error::Error + Send + Sync + 'static>;
#[cfg(feature = "compression")]
pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings};
pub use crate::service::interceptor::InterceptedService;
pub use bytes::Bytes;
pub use http_body::Body;

pub type BoxFuture<T, E> = self::Pin<Box<dyn self::Future<Output = Result<T, E>> + Send + 'static>>;
Expand Down
67 changes: 61 additions & 6 deletions tonic/src/service/interceptor.rs
Expand Up @@ -3,6 +3,7 @@
//! See [`Interceptor`] for more details.

use crate::{request::SanitizeHeaders, Status};
use bytes::Bytes;
use pin_project::pin_project;
use std::{
fmt,
Expand Down Expand Up @@ -140,9 +141,11 @@ where

impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
where
ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
F: Interceptor,
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
S::Error: Into<crate::Error>,
ResBody::Error: Into<crate::Error>,
{
type Response = http::Response<ResBody>;
type Error = crate::Error;
Expand Down Expand Up @@ -215,15 +218,18 @@ impl<F, E, B> Future for ResponseFuture<F>
where
F: Future<Output = Result<http::Response<B>, E>>,
E: Into<crate::Error>,
B: Default + http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<crate::Error>,
{
type Output = Result<http::Response<B>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Future(future) => future.poll(cx).map_err(Into::into),
KindProj::Error(status) => {
let error = status.take().unwrap().into();
Poll::Ready(Err(error))
let response = status.take().unwrap().to_http().map(|_| B::default());

Poll::Ready(Ok(response))
}
}
}
Expand All @@ -233,11 +239,38 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use http::header::HeaderMap;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::ServiceExt;

#[derive(Debug, Default)]
struct TestBody;

impl http_body::Body for TestBody {
type Data = Bytes;
type Error = Status;

fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
Poll::Ready(None)
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}

#[tokio::test]
async fn doesnt_remove_headers() {
let svc = tower::service_fn(|request: http::Request<hyper::Body>| async move {
async fn doesnt_remove_headers_from_requests() {
let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
assert_eq!(
request
.headers()
Expand All @@ -246,7 +279,7 @@ mod tests {
"test-tonic"
);

Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
Expand All @@ -257,14 +290,36 @@ mod tests {
.expect("missing in interceptor"),
"test-tonic"
);

Ok(request)
});

let request = http::Request::builder()
.header("user-agent", "test-tonic")
.body(hyper::Body::empty())
.body(TestBody)
.unwrap();

svc.oneshot(request).await.unwrap();
}

#[tokio::test]
async fn handles_intercepted_status_as_response() {
let message = "Blocked by the interceptor";
let expected = Status::permission_denied(message).to_http();

let svc = tower::service_fn(|_: http::Request<TestBody>| async {
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
Err(Status::permission_denied(message))
});

let request = http::Request::builder().body(TestBody).unwrap();
let response = svc.oneshot(request).await.unwrap();

assert_eq!(expected.status(), response.status());
assert_eq!(expected.version(), response.version());
assert_eq!(expected.headers(), response.headers());
}
}

0 comments on commit 2a61ea0

Please sign in to comment.