Skip to content

Commit

Permalink
fix(tonic): Handle interceptor errors as responses (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
NAlexPear committed Nov 19, 2021
1 parent c62f382 commit b8ae4da
Showing 1 changed file with 59 additions and 6 deletions.
65 changes: 59 additions & 6 deletions tonic/src/service/interceptor.rs
Expand Up @@ -3,6 +3,7 @@
//! See [`Interceptor`] for more details.

use crate::{request::SanitizeHeaders, Status};
use bytes::Bytes;
use pin_project::pin_project;
use std::{
fmt,
Expand Down Expand Up @@ -140,6 +141,7 @@ where

impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
where
ResBody: Default + http_body::Body<Data = Bytes, Error = Status> + Send + 'static,
F: Interceptor,
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
S::Error: Into<crate::Error>,
Expand Down Expand Up @@ -215,15 +217,17 @@ impl<F, E, B> Future for ResponseFuture<F>
where
F: Future<Output = Result<http::Response<B>, E>>,
E: Into<crate::Error>,
B: Default + http_body::Body<Data = Bytes, Error = Status> + Send + 'static,
{
type Output = Result<http::Response<B>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Future(future) => future.poll(cx).map_err(Into::into),
KindProj::Error(status) => {
let error = status.take().unwrap().into();
Poll::Ready(Err(error))
let response = status.take().unwrap().to_http().map(|_| B::default());

Poll::Ready(Ok(response))
}
}
}
Expand All @@ -233,11 +237,38 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use http::header::HeaderMap;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::ServiceExt;

#[derive(Debug, Default)]
struct TestBody;

impl http_body::Body for TestBody {
type Data = Bytes;
type Error = Status;

fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
Poll::Ready(None)
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}

#[tokio::test]
async fn doesnt_remove_headers() {
let svc = tower::service_fn(|request: http::Request<hyper::Body>| async move {
async fn doesnt_remove_headers_from_requests() {
let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
assert_eq!(
request
.headers()
Expand All @@ -246,7 +277,7 @@ mod tests {
"test-tonic"
);

Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
Expand All @@ -257,14 +288,36 @@ mod tests {
.expect("missing in interceptor"),
"test-tonic"
);

Ok(request)
});

let request = http::Request::builder()
.header("user-agent", "test-tonic")
.body(hyper::Body::empty())
.body(TestBody)
.unwrap();

svc.oneshot(request).await.unwrap();
}

#[tokio::test]
async fn handles_intercepted_status_as_response() {
let message = "Blocked by the interceptor";
let expected = Status::permission_denied(message).to_http();

let svc = tower::service_fn(|_: http::Request<TestBody>| async {
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
Err(Status::permission_denied(message))
});

let request = http::Request::builder().body(TestBody).unwrap();
let response = svc.oneshot(request).await.unwrap();

assert_eq!(expected.status(), response.status());
assert_eq!(expected.version(), response.version());
assert_eq!(expected.headers(), response.headers());
}
}

0 comments on commit b8ae4da

Please sign in to comment.