Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle interceptor errors as responses (#840) #842

Merged
merged 1 commit into from Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions tonic-build/src/client.rs
Expand Up @@ -15,7 +15,7 @@ pub fn generate<T: Service>(
attributes: &Attributes,
) -> TokenStream {
let service_ident = quote::format_ident!("{}Client", service.name());
let client_mod = quote::format_ident!("{}_client", naive_snake_case(&service.name()));
let client_mod = quote::format_ident!("{}_client", naive_snake_case(service.name()));
let methods = generate_methods(service, emit_package, proto_path, compile_well_known_types);

let connect = generate_connect(&service_ident);
Expand Down Expand Up @@ -57,8 +57,8 @@ pub fn generate<T: Service>(
impl<T> #service_ident<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::ResponseBody: Body + Send + 'static,
T::Error: Into<StdError>,
T::ResponseBody: Default + Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(inner: T) -> Self {
Expand Down
1 change: 1 addition & 0 deletions tonic/src/codegen.rs
Expand Up @@ -13,6 +13,7 @@ pub type StdError = Box<dyn std::error::Error + Send + Sync + 'static>;
#[cfg(feature = "compression")]
pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings};
pub use crate::service::interceptor::InterceptedService;
pub use bytes::Bytes;
pub use http_body::Body;

pub type BoxFuture<T, E> = self::Pin<Box<dyn self::Future<Output = Result<T, E>> + Send + 'static>>;
Expand Down
67 changes: 61 additions & 6 deletions tonic/src/service/interceptor.rs
Expand Up @@ -3,6 +3,7 @@
//! See [`Interceptor`] for more details.

use crate::{request::SanitizeHeaders, Status};
use bytes::Bytes;
use pin_project::pin_project;
use std::{
fmt,
Expand Down Expand Up @@ -140,9 +141,11 @@ where

impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
where
ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
F: Interceptor,
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
S::Error: Into<crate::Error>,
ResBody::Error: Into<crate::Error>,
{
type Response = http::Response<ResBody>;
type Error = crate::Error;
Expand Down Expand Up @@ -215,15 +218,18 @@ impl<F, E, B> Future for ResponseFuture<F>
where
F: Future<Output = Result<http::Response<B>, E>>,
E: Into<crate::Error>,
B: Default + http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<crate::Error>,
{
type Output = Result<http::Response<B>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Future(future) => future.poll(cx).map_err(Into::into),
KindProj::Error(status) => {
let error = status.take().unwrap().into();
Poll::Ready(Err(error))
let response = status.take().unwrap().to_http().map(|_| B::default());

Poll::Ready(Ok(response))
}
}
}
Expand All @@ -233,11 +239,38 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use http::header::HeaderMap;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::ServiceExt;

#[derive(Debug, Default)]
struct TestBody;

impl http_body::Body for TestBody {
type Data = Bytes;
type Error = Status;

fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
Poll::Ready(None)
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}

#[tokio::test]
async fn doesnt_remove_headers() {
let svc = tower::service_fn(|request: http::Request<hyper::Body>| async move {
async fn doesnt_remove_headers_from_requests() {
let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
assert_eq!(
request
.headers()
Expand All @@ -246,7 +279,7 @@ mod tests {
"test-tonic"
);

Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
Expand All @@ -257,14 +290,36 @@ mod tests {
.expect("missing in interceptor"),
"test-tonic"
);

Ok(request)
});

let request = http::Request::builder()
.header("user-agent", "test-tonic")
.body(hyper::Body::empty())
.body(TestBody)
.unwrap();

svc.oneshot(request).await.unwrap();
}

#[tokio::test]
async fn handles_intercepted_status_as_response() {
let message = "Blocked by the interceptor";
let expected = Status::permission_denied(message).to_http();

let svc = tower::service_fn(|_: http::Request<TestBody>| async {
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
Err(Status::permission_denied(message))
});

let request = http::Request::builder().body(TestBody).unwrap();
let response = svc.oneshot(request).await.unwrap();

assert_eq!(expected.status(), response.status());
assert_eq!(expected.version(), response.version());
assert_eq!(expected.headers(), response.headers());
}
}