/
origin.rs
100 lines (81 loc) · 2.55 KB
/
origin.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use futures::future::BoxFuture;
use futures_util::FutureExt;
use integration_tests::pb::test_client;
use integration_tests::pb::{test_server, Input, Output};
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::codegen::http::Request;
use tonic::{
transport::{Endpoint, Server},
Response, Status,
};
use tower::Layer;
use tower::Service;
/// End-to-end test: a client whose `Endpoint` is configured with the origin
/// `https://docs.rs` performs a unary call against a local server wrapped in
/// `OriginLayer`, which asserts that each incoming request's URI host is
/// `docs.rs`.
#[tokio::test]
async fn writes_origin_header() {
    // Minimal service implementation: every unary call succeeds with an empty `Output`.
    struct Svc;

    #[tonic::async_trait]
    impl test_server::Test for Svc {
        async fn unary_call(
            &self,
            _req: tonic::Request<Input>,
        ) -> Result<Response<Output>, Status> {
            Ok(Response::new(Output {}))
        }
    }

    // One-shot channel used to ask the server to shut down at the end of the test.
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let service = test_server::TestServer::new(Svc);

    let server_task = tokio::spawn(async move {
        let addr: std::net::SocketAddr = "127.0.0.1:1442".parse().unwrap();
        Server::builder()
            .layer(OriginLayer {})
            .add_service(service)
            // `map(drop)` discards the receiver's result so the signal future yields `()`.
            .serve_with_shutdown(addr, shutdown_rx.map(drop))
            .await
            .unwrap();
    });

    // Presumably gives the spawned server time to start listening before we connect.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let endpoint = Endpoint::from_static("http://127.0.0.1:1442")
        .origin("https://docs.rs".parse().expect("valid uri"));
    let channel = endpoint.connect().await.unwrap();
    let mut client = test_client::TestClient::new(channel);

    // Fail the test with the status message if the call did not succeed.
    let outcome = client.unary_call(Input {}).await;
    if let Err(status) = outcome {
        panic!("{}", status.message());
    }

    // Signal shutdown, then wait for the server task to finish.
    shutdown_tx.send(()).unwrap();
    server_task.await.unwrap();
}
/// Tower layer that wraps a service in [`OriginService`].
#[derive(Clone)]
struct OriginLayer {}

impl<Inner> Layer<Inner> for OriginLayer {
    type Service = OriginService<Inner>;

    /// Wraps `service` in an [`OriginService`].
    fn layer(&self, service: Inner) -> Self::Service {
        OriginService { inner: service }
    }
}
/// Service wrapper that asserts every request's URI host is `docs.rs`
/// before handing the request to the wrapped service.
#[derive(Clone)]
struct OriginService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
}

impl<T> Service<Request<tonic::transport::Body>> for OriginService<T>
where
    T: Service<Request<tonic::transport::Body>>,
    T::Future: Send + 'static,
    T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Response = T::Response;
    type Error = Box<dyn std::error::Error + Send + Sync>;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Forwards readiness to the wrapped service, converting its error into
    /// the boxed error type.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Ready(outcome) => Poll::Ready(outcome.map_err(Into::into)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Asserts that the request's URI host is `docs.rs` (panicking otherwise),
    /// then delegates to the wrapped service and boxes the resulting future.
    fn call(&mut self, request: Request<tonic::transport::Body>) -> Self::Future {
        let host = request.uri().host();
        assert_eq!(host, Some("docs.rs"));

        let pending = self.inner.call(request);
        Box::pin(async move { pending.await.map_err(|err| err.into()) })
    }
}