/
connect_info.rs
124 lines (96 loc) · 3.62 KB
/
connect_info.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::{
transport::{server::TcpConnectInfo, Endpoint, Server},
Request, Response, Status,
};
#[tokio::test]
async fn getting_connect_info() {
struct Svc;
#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
assert!(req.remote_addr().is_some());
assert!(req.extensions().get::<TcpConnectInfo>().is_some());
Ok(Response::new(Output {}))
}
}
let svc = test_server::TestServer::new(Svc);
let (tx, rx) = oneshot::channel::<()>();
let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1400".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});
tokio::time::sleep(Duration::from_millis(100)).await;
let channel = Endpoint::from_static("http://127.0.0.1:1400")
.connect()
.await
.unwrap();
let mut client = test_client::TestClient::new(channel);
client.unary_call(Input {}).await.unwrap();
tx.send(()).unwrap();
jh.await.unwrap();
}
/// Unix-domain-socket variant of the connect-info test: the handler must see a
/// `UdsConnectInfo` extension carrying an unnamed peer address and the client
/// process credentials, while `Request::remote_addr()` stays `None`.
#[cfg(unix)]
pub mod unix {
    use std::convert::TryFrom as _;

    use futures_util::FutureExt;
    use tokio::{
        net::{UnixListener, UnixStream},
        sync::oneshot,
    };
    use tokio_stream::wrappers::UnixListenerStream;
    use tonic::{
        transport::{server::UdsConnectInfo, Endpoint, Server, Uri},
        Request, Response, Status,
    };
    use tower::service_fn;

    use integration_tests::pb::{test_client, test_server, Input, Output};

    struct Svc {}

    #[tonic::async_trait]
    impl test_server::Test for Svc {
        async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
            let conn_info = req.extensions().get::<UdsConnectInfo>().unwrap();

            // Client-side unix sockets are unnamed.
            assert!(req.remote_addr().is_none());
            assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
            // This should contain process credentials for the client socket.
            assert!(conn_info.peer_cred.is_some());

            Ok(Response::new(Output {}))
        }
    }

    #[tokio::test]
    async fn getting_connect_info() {
        // Use a per-process socket path so concurrent runs do not collide.
        let unix_socket_path = std::env::temp_dir()
            .join(format!("uds-integration-test-{}", std::process::id()));

        // A run that panicked before cleanup leaves the socket file behind,
        // which would make `bind` fail with `AddrInUse`; remove it first.
        match std::fs::remove_file(&unix_socket_path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => panic!("failed to remove stale socket file: {}", e),
        }

        let uds = UnixListener::bind(&unix_socket_path).unwrap();
        let uds_stream = UnixListenerStream::new(uds);

        let service = test_server::TestServer::new(Svc {});
        let (tx, rx) = oneshot::channel::<()>();

        let jh = tokio::spawn(async move {
            Server::builder()
                .add_service(service)
                .serve_with_incoming_shutdown(uds_stream, rx.map(drop))
                .await
                .unwrap();
        });

        // Take a copy before moving into the `service_fn` closure so that the closure
        // can implement `FnMut`.
        let path = unix_socket_path.clone();

        // `Endpoint` requires a URI, but the connector below ignores it and always
        // dials the unix socket path, so this address is only a placeholder.
        let channel = Endpoint::try_from("http://[::]:50051")
            .unwrap()
            .connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone())))
            .await
            .unwrap();

        let mut client = test_client::TestClient::new(channel);
        client.unary_call(Input {}).await.unwrap();

        tx.send(()).unwrap();
        jh.await.unwrap();

        std::fs::remove_file(unix_socket_path).unwrap();
    }
}