diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 289054b57..543f8e006 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -23,6 +23,7 @@ http-body = "0.4" hyper = "0.14" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} +tower-http = { version = "0.2", features = ["set-header", "trace"] } tower-service = "0.3" tracing-subscriber = {version = "0.3", features = ["env-filter"]} diff --git a/tests/integration_tests/tests/client_layer.rs b/tests/integration_tests/tests/client_layer.rs new file mode 100644 index 000000000..c33f388d7 --- /dev/null +++ b/tests/integration_tests/tests/client_layer.rs @@ -0,0 +1,58 @@ +use std::time::Duration; + +use futures::{channel::oneshot, FutureExt}; +use http::{header::HeaderName, HeaderValue}; +use integration_tests::pb::{test_client::TestClient, test_server, Input, Output}; +use tonic::{ + transport::{Endpoint, Server}, + Request, Response, Status, +}; +use tower::ServiceBuilder; +use tower_http::{set_header::SetRequestHeaderLayer, trace::TraceLayer}; + +#[tokio::test] +async fn connect_supports_standard_tower_layers() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + match req.metadata().get("x-test") { + Some(_) => Ok(Response::new(Output {})), + None => Err(Status::internal("user-agent header is missing")), + } + } + } + + let (tx, rx) = oneshot::channel(); + let svc = test_server::TestServer::new(Svc); + + // Start the server now, second call should succeed + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1340".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + let channel = Endpoint::from_static("http://127.0.0.1:1340").connect_lazy(); + + // prior to https://github.com/hyperium/tonic/pull/974 + // this would not compile. (specifically the `TraceLayer`) + let mut client = TestClient::new( + ServiceBuilder::new() + .layer(SetRequestHeaderLayer::overriding( + HeaderName::from_static("x-test"), + HeaderValue::from_static("test-header"), + )) + .layer(TraceLayer::new_for_grpc()) + .service(channel), + ); + + tokio::time::sleep(Duration::from_millis(100)).await; + client.unary_call(Request::new(Input {})).await.unwrap(); + + tx.send(()).unwrap(); + jh.await.unwrap(); +} diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index daa05fa4e..913d3e7bb 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -58,7 +58,7 @@ pub fn generate( where T: tonic::client::GrpcService, T::Error: Into, - T::ResponseBody: Default + Body + Send + 'static, + T::ResponseBody: Body + Send + 'static, ::Error: Into + Send, { pub fn new(inner: T) -> Self { @@ -69,6 +69,7 @@ pub fn generate( pub fn with_interceptor(inner: T, interceptor: F) -> #service_ident> where F: tonic::service::Interceptor, + T::ResponseBody: Default, T: tonic::codegen::Service< http::Request, Response = http::Response<>::ResponseBody>