diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs new file mode 100644 index 000000000..17bbc9cdd --- /dev/null +++ b/tests/integration_tests/tests/origin.rs @@ -0,0 +1,100 @@ +use futures::future::BoxFuture; +use futures_util::FutureExt; +use integration_tests::pb::test_client; +use integration_tests::pb::{test_server, Input, Output}; +use std::task::Context; +use std::task::Poll; +use std::time::Duration; +use tokio::sync::oneshot; +use tonic::codegen::http::Request; +use tonic::{ + transport::{Endpoint, Server}, + Response, Status, +}; +use tower::Layer; +use tower::Service; + +#[tokio::test] +async fn writes_origin_header() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call( + &self, + _req: tonic::Request, + ) -> Result, Status> { + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .layer(OriginLayer {}) + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1442".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1442") + .origin("https://docs.rs".parse().expect("valid uri")) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + match client.unary_call(Input {}).await { + Ok(_) => {} + Err(status) => panic!("{}", status.message()), + } + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[derive(Clone)] +struct OriginLayer {} + +impl Layer for OriginLayer { + type Service = OriginService; + + fn layer(&self, inner: S) -> Self::Service { + OriginService { inner } + } +} + +#[derive(Clone)] +struct OriginService { + inner: S, +} + +impl Service> for OriginService +where + T: Service>, + T::Future: Send + 'static, + T::Error: Into>, +{ + type Response = T::Response; + type Error = Box; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + assert_eq!(req.uri().host(), Some("docs.rs")); + let fut = self.inner.call(req); + + Box::pin(async move { fut.await.map_err(Into::into) }) + } +} diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 5d08e22f3..32ce858b6 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -24,6 +24,7 @@ use tower::make::MakeConnection; #[derive(Clone)] pub struct Endpoint { pub(crate) uri: Uri, + pub(crate) origin: Option, pub(crate) user_agent: Option, pub(crate) timeout: Option, pub(crate) concurrency_limit: Option, @@ -106,6 +107,25 @@ impl Endpoint { .map_err(|_| Error::new_invalid_user_agent()) } + /// Set a custom origin. + /// + /// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer + /// which serves multiple services at the same time. + /// It will play the role of SNI (Server Name Indication). + /// + /// ``` + /// # use tonic::transport::Endpoint; + /// # let mut builder = Endpoint::from_static("https://proxy.com"); + /// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI")); + /// // origin: "https://example.com" + /// ``` + pub fn origin(self, origin: Uri) -> Self { + Endpoint { + origin: Some(origin), + ..self + } + } + /// Apply a timeout to each request. /// /// ``` @@ -395,6 +415,7 @@ impl From for Endpoint { fn from(uri: Uri) -> Self { Self { uri, + origin: None, user_agent: None, concurrency_limit: None, rate_limit: None, diff --git a/tonic/src/transport/service/add_origin.rs b/tonic/src/transport/service/add_origin.rs index 50d5b6d96..b706ee995 100644 --- a/tonic/src/transport/service/add_origin.rs +++ b/tonic/src/transport/service/add_origin.rs @@ -1,4 +1,6 @@ use futures_core::future::BoxFuture; +use http::uri::Authority; +use http::uri::Scheme; use http::{Request, Uri}; use std::task::{Context, Poll}; use tower_service::Service; @@ -6,12 +8,21 @@ use tower_service::Service; #[derive(Debug)] pub(crate) struct AddOrigin { inner: T, - origin: Uri, + scheme: Option, + authority: Option, } impl AddOrigin { pub(crate) fn new(inner: T, origin: Uri) -> Self { - Self { inner, origin } + let http::uri::Parts { + scheme, authority, .. + } = origin.into_parts(); + + Self { + inner, + scheme, + authority, + } } } @@ -30,24 +41,24 @@ where } fn call(&mut self, req: Request) -> Self::Future { - // Split the request into the head and the body. - let (mut head, body) = req.into_parts(); - - // Split the request URI into parts. - let mut uri: http::uri::Parts = head.uri.into(); - let set_uri = self.origin.clone().into_parts(); - - if set_uri.scheme.is_none() || set_uri.authority.is_none() { + if self.scheme.is_none() || self.authority.is_none() { let err = crate::transport::Error::new_invalid_uri(); return Box::pin(async move { Err::(err.into()) }); } - // Update the URI parts, setting hte scheme and authority - uri.scheme = Some(set_uri.scheme.expect("expected scheme")); - uri.authority = Some(set_uri.authority.expect("expected authority")); + // Split the request into the head and the body. + let (mut head, body) = req.into_parts(); // Update the the request URI - head.uri = http::Uri::from_parts(uri).expect("valid uri"); + head.uri = { + // Split the request URI into parts. + let mut uri: http::uri::Parts = head.uri.into(); + // Update the URI parts, setting hte scheme and authority + uri.scheme = self.scheme.clone(); + uri.authority = self.authority.clone(); + + http::Uri::from_parts(uri).expect("valid uri") + }; let request = Request::from_parts(head, body); diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 3aee2681c..a12f2d6a9 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -55,7 +55,11 @@ impl Connection { } let stack = ServiceBuilder::new() - .layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone())) + .layer_fn(|s| { + let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone(); + + AddOrigin::new(s, origin) + }) .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone())) .layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout)) .option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))