From 1db6bafafd09ac89a38844a3ba9193a218ec509b Mon Sep 17 00:00:00 2001 From: Quentin Perez Date: Sat, 18 Jun 2022 12:57:01 +0200 Subject: [PATCH] add integration test --- tests/integration_tests/tests/origin.rs | 100 ++++++++++++++++++++++ tonic/src/transport/service/connection.rs | 6 +- 2 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 tests/integration_tests/tests/origin.rs diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs new file mode 100644 index 000000000..17bbc9cdd --- /dev/null +++ b/tests/integration_tests/tests/origin.rs @@ -0,0 +1,100 @@ +use futures::future::BoxFuture; +use futures_util::FutureExt; +use integration_tests::pb::test_client; +use integration_tests::pb::{test_server, Input, Output}; +use std::task::Context; +use std::task::Poll; +use std::time::Duration; +use tokio::sync::oneshot; +use tonic::codegen::http::Request; +use tonic::{ + transport::{Endpoint, Server}, + Response, Status, +}; +use tower::Layer; +use tower::Service; + +#[tokio::test] +async fn writes_origin_header() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call( + &self, + _req: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .layer(OriginLayer {}) + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1442".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1442") + .origin("https://docs.rs".parse().expect("valid uri")) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + match client.unary_call(Input {}).await { + Ok(_) => {} + Err(status) => panic!("{}", status.message()), + } + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[derive(Clone)] +struct OriginLayer {} + +impl Layer for OriginLayer { + type Service = OriginService; + + fn layer(&self, inner: S) -> Self::Service { + OriginService { inner } + } +} + +#[derive(Clone)] +struct OriginService { + inner: S, +} + +impl Service> for OriginService +where + T: Service>, + T::Future: Send + 'static, + T::Error: Into>, +{ + type Response = T::Response; + type Error = Box; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + assert_eq!(req.uri().host(), Some("docs.rs")); + let fut = self.inner.call(req); + + Box::pin(async move { fut.await.map_err(Into::into) }) + } +} diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index fe461240c..a12f2d6a9 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -56,11 +56,7 @@ impl Connection { let stack = ServiceBuilder::new() .layer_fn(|s| { - let origin = endpoint - .origin - .as_ref() - .unwrap_or_else(|| &endpoint.uri) - .clone(); + let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone(); AddOrigin::new(s, origin) })