Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tonic: Introduce a new method on Endpoint to override the origin #1013

Merged
merged 2 commits into from Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/integration_tests/tests/origin.rs
@@ -0,0 +1,100 @@
use futures::future::BoxFuture;
use futures_util::FutureExt;
use integration_tests::pb::test_client;
use integration_tests::pb::{test_server, Input, Output};
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::codegen::http::Request;
use tonic::{
transport::{Endpoint, Server},
Response, Status,
};
use tower::Layer;
use tower::Service;

#[tokio::test]
async fn writes_origin_header() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(
&self,
_req: tonic::Request<Input>,
) -> Result<Response<Output>, Status> {
Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.layer(OriginLayer {})
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1442".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1442")
.origin("https://docs.rs".parse().expect("valid uri"))
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

match client.unary_call(Input {}).await {
Ok(_) => {}
Err(status) => panic!("{}", status.message()),
}

tx.send(()).unwrap();

jh.await.unwrap();
}

#[derive(Clone)]
struct OriginLayer {}

impl<S> Layer<S> for OriginLayer {
type Service = OriginService<S>;

fn layer(&self, inner: S) -> Self::Service {
OriginService { inner }
}
}

#[derive(Clone)]
struct OriginService<S> {
inner: S,
}

impl<T> Service<Request<tonic::transport::Body>> for OriginService<T>
where
T: Service<Request<tonic::transport::Body>>,
T::Future: Send + 'static,
T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = T::Response;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, req: Request<tonic::transport::Body>) -> Self::Future {
assert_eq!(req.uri().host(), Some("docs.rs"));
let fut = self.inner.call(req);

Box::pin(async move { fut.await.map_err(Into::into) })
}
}
21 changes: 21 additions & 0 deletions tonic/src/transport/channel/endpoint.rs
Expand Up @@ -24,6 +24,7 @@ use tower::make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
pub(crate) origin: Option<Uri>,
pub(crate) user_agent: Option<HeaderValue>,
pub(crate) timeout: Option<Duration>,
pub(crate) concurrency_limit: Option<usize>,
Expand Down Expand Up @@ -106,6 +107,25 @@ impl Endpoint {
.map_err(|_| Error::new_invalid_user_agent())
}

/// Set a custom origin.
///
/// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
/// which serves multiple services at the same time.
/// It will play the role of SNI (Server Name Indication).
///
/// ```
/// # use tonic::transport::Endpoint;
/// # let mut builder = Endpoint::from_static("https://proxy.com");
QuentinPerez marked this conversation as resolved.
Show resolved Hide resolved
/// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
/// // origin: "https://example.com"
/// ```
pub fn origin(self, origin: Uri) -> Self {
Endpoint {
origin: Some(origin),
..self
}
}

/// Apply a timeout to each request.
///
/// ```
Expand Down Expand Up @@ -395,6 +415,7 @@ impl From<Uri> for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
origin: None,
user_agent: None,
concurrency_limit: None,
rate_limit: None,
Expand Down
39 changes: 25 additions & 14 deletions tonic/src/transport/service/add_origin.rs
@@ -1,17 +1,28 @@
use futures_core::future::BoxFuture;
use http::uri::Authority;
use http::uri::Scheme;
use http::{Request, Uri};
use std::task::{Context, Poll};
use tower_service::Service;

#[derive(Debug)]
pub(crate) struct AddOrigin<T> {
inner: T,
origin: Uri,
scheme: Option<Scheme>,
authority: Option<Authority>,
}

impl<T> AddOrigin<T> {
pub(crate) fn new(inner: T, origin: Uri) -> Self {
Self { inner, origin }
let http::uri::Parts {
scheme, authority, ..
} = origin.into_parts();

Self {
inner,
scheme,
authority,
}
}
}

Expand All @@ -30,24 +41,24 @@ where
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
// Split the request into the head and the body.
let (mut head, body) = req.into_parts();

// Split the request URI into parts.
let mut uri: http::uri::Parts = head.uri.into();
let set_uri = self.origin.clone().into_parts();

if set_uri.scheme.is_none() || set_uri.authority.is_none() {
if self.scheme.is_none() || self.authority.is_none() {
let err = crate::transport::Error::new_invalid_uri();
return Box::pin(async move { Err::<Self::Response, _>(err.into()) });
}

// Update the URI parts, setting hte scheme and authority
uri.scheme = Some(set_uri.scheme.expect("expected scheme"));
uri.authority = Some(set_uri.authority.expect("expected authority"));
// Split the request into the head and the body.
let (mut head, body) = req.into_parts();

// Update the the request URI
head.uri = http::Uri::from_parts(uri).expect("valid uri");
head.uri = {
// Split the request URI into parts.
let mut uri: http::uri::Parts = head.uri.into();
// Update the URI parts, setting hte scheme and authority
uri.scheme = self.scheme.clone();
uri.authority = self.authority.clone();

http::Uri::from_parts(uri).expect("valid uri")
};

let request = Request::from_parts(head, body);

Expand Down
6 changes: 5 additions & 1 deletion tonic/src/transport/service/connection.rs
Expand Up @@ -55,7 +55,11 @@ impl Connection {
}

let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| {
let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone();

AddOrigin::new(s, origin)
})
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
Expand Down