Skip to content

Commit

Permalink
tonic: Introduce a new method on Endpoint to override the origin (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin Perez committed Jun 20, 2022
1 parent 8287988 commit 4388d82
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 15 deletions.
100 changes: 100 additions & 0 deletions tests/integration_tests/tests/origin.rs
@@ -0,0 +1,100 @@
use futures::future::BoxFuture;
use futures_util::FutureExt;
use integration_tests::pb::test_client;
use integration_tests::pb::{test_server, Input, Output};
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::codegen::http::Request;
use tonic::{
transport::{Endpoint, Server},
Response, Status,
};
use tower::Layer;
use tower::Service;

#[tokio::test]
async fn writes_origin_header() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(
&self,
_req: tonic::Request<Input>,
) -> Result<Response<Output>, Status> {
Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.layer(OriginLayer {})
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1442".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1442")
.origin("https://docs.rs".parse().expect("valid uri"))
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

match client.unary_call(Input {}).await {
Ok(_) => {}
Err(status) => panic!("{}", status.message()),
}

tx.send(()).unwrap();

jh.await.unwrap();
}

#[derive(Clone)]
struct OriginLayer {}

impl<S> Layer<S> for OriginLayer {
type Service = OriginService<S>;

fn layer(&self, inner: S) -> Self::Service {
OriginService { inner }
}
}

#[derive(Clone)]
struct OriginService<S> {
inner: S,
}

impl<T> Service<Request<tonic::transport::Body>> for OriginService<T>
where
T: Service<Request<tonic::transport::Body>>,
T::Future: Send + 'static,
T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = T::Response;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, req: Request<tonic::transport::Body>) -> Self::Future {
assert_eq!(req.uri().host(), Some("docs.rs"));
let fut = self.inner.call(req);

Box::pin(async move { fut.await.map_err(Into::into) })
}
}
21 changes: 21 additions & 0 deletions tonic/src/transport/channel/endpoint.rs
Expand Up @@ -24,6 +24,7 @@ use tower::make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
pub(crate) origin: Option<Uri>,
pub(crate) user_agent: Option<HeaderValue>,
pub(crate) timeout: Option<Duration>,
pub(crate) concurrency_limit: Option<usize>,
Expand Down Expand Up @@ -106,6 +107,25 @@ impl Endpoint {
.map_err(|_| Error::new_invalid_user_agent())
}

/// Set a custom origin.
///
/// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
/// which serves multiple services at the same time.
/// It will play the role of SNI (Server Name Indication).
///
/// ```
/// # use tonic::transport::Endpoint;
/// # let mut builder = Endpoint::from_static("https://proxy.com");
/// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
/// // origin: "https://example.com"
/// ```
pub fn origin(self, origin: Uri) -> Self {
Endpoint {
origin: Some(origin),
..self
}
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -395,6 +415,7 @@ impl From<Uri> for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
origin: None,
user_agent: None,
concurrency_limit: None,
rate_limit: None,
Expand Down
39 changes: 25 additions & 14 deletions tonic/src/transport/service/add_origin.rs
@@ -1,17 +1,28 @@
use futures_core::future::BoxFuture;
use http::uri::Authority;
use http::uri::Scheme;
use http::{Request, Uri};
use std::task::{Context, Poll};
use tower_service::Service;

#[derive(Debug)]
pub(crate) struct AddOrigin<T> {
inner: T,
origin: Uri,
scheme: Option<Scheme>,
authority: Option<Authority>,
}

impl<T> AddOrigin<T> {
pub(crate) fn new(inner: T, origin: Uri) -> Self {
Self { inner, origin }
let http::uri::Parts {
scheme, authority, ..
} = origin.into_parts();

Self {
inner,
scheme,
authority,
}
}
}

Expand All @@ -30,24 +41,24 @@ where
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
// Split the request into the head and the body.
let (mut head, body) = req.into_parts();

// Split the request URI into parts.
let mut uri: http::uri::Parts = head.uri.into();
let set_uri = self.origin.clone().into_parts();

if set_uri.scheme.is_none() || set_uri.authority.is_none() {
if self.scheme.is_none() || self.authority.is_none() {
let err = crate::transport::Error::new_invalid_uri();
return Box::pin(async move { Err::<Self::Response, _>(err.into()) });
}

// Update the URI parts, setting hte scheme and authority
uri.scheme = Some(set_uri.scheme.expect("expected scheme"));
uri.authority = Some(set_uri.authority.expect("expected authority"));
// Split the request into the head and the body.
let (mut head, body) = req.into_parts();

// Update the the request URI
head.uri = http::Uri::from_parts(uri).expect("valid uri");
head.uri = {
// Split the request URI into parts.
let mut uri: http::uri::Parts = head.uri.into();
// Update the URI parts, setting hte scheme and authority
uri.scheme = self.scheme.clone();
uri.authority = self.authority.clone();

http::Uri::from_parts(uri).expect("valid uri")
};

let request = Request::from_parts(head, body);

Expand Down
6 changes: 5 additions & 1 deletion tonic/src/transport/service/connection.rs
Expand Up @@ -55,7 +55,11 @@ impl Connection {
}

let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| {
let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone();

AddOrigin::new(s, origin)
})
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
Expand Down

0 comments on commit 4388d82

Please sign in to comment.