Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More ergonomic support for using a unix socket with a tonic server #1

Merged
merged 3 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
93 changes: 15 additions & 78 deletions examples/src/uds/server.rs
@@ -1,10 +1,14 @@
#![cfg_attr(not(unix), allow(unused_imports))]

use futures::TryFutureExt;
use std::path::Path;
#[cfg(unix)]
use tokio::net::UnixListener;
use tonic::{transport::Server, Request, Response, Status};
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Server},
Request, Response, Status,
};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand All @@ -26,8 +30,13 @@ impl Greeter for MyGreeter {
) -> Result<Response<HelloReply>, Status> {
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
let conn_info = request.extensions().get::<UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);

// Client-side unix sockets are unnamed.
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());
}

let reply = hello_world::HelloReply {
Expand All @@ -46,89 +55,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let greeter = MyGreeter::default();

let incoming = {
let uds = UnixListener::bind(path)?;

async_stream::stream! {
loop {
let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await;

yield item;
}
}
};
let uds = UnixListener::bind(path)?;
let uds_stream = UnixListenerStream::new(uds);

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(incoming)
.serve_with_incoming(uds_stream)
.await?;

Ok(())
}

#[cfg(unix)]
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::server::Connected;

#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for UnixStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}

#[cfg(not(unix))]
fn main() {
panic!("The `uds` example only works on unix systems!");
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/Cargo.toml
Expand Up @@ -12,7 +12,7 @@ version = "0.1.0"
bytes = "1.0"
futures-util = "0.3"
prost = "0.9"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tokio = {version = "1.0", features = ["fs", "macros", "rt-multi-thread", "net"]}
tonic = {path = "../../tonic"}

[dev-dependencies]
Expand Down
73 changes: 73 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Expand Up @@ -48,3 +48,76 @@ async fn getting_connect_info() {

jh.await.unwrap();
}

#[cfg(unix)]
pub mod unix {
use std::path::Path;

use futures_util::FutureExt;
use tokio::{
net::{UnixListener, UnixStream},
sync::oneshot,
};
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Endpoint, Server, Uri},
Request, Response, Status,
};
use tower::service_fn;

use integration_tests::pb::{test_client, test_server, Input, Output};

struct Svc {}

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let conn_info = req.extensions().get::<UdsConnectInfo>().unwrap();

// Client-side unix sockets are unnamed.
assert!(req.remote_addr().is_none());
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());

Ok(Response::new(Output {}))
}
}

#[tokio::test]
async fn getting_connect_info() {
let unix_socket_path = "/tmp/uds-integration-test";
agreen17 marked this conversation as resolved.
Show resolved Hide resolved
let uds = UnixListener::bind(unix_socket_path).unwrap();
let uds_stream = UnixListenerStream::new(uds);

let service = test_server::TestServer::new(Svc {});
let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(service)
.serve_with_incoming_shutdown(uds_stream, rx.map(drop))
.await
.unwrap();
});

let channel = Endpoint::try_from("http://[::]:50051")
.unwrap()
.connect_with_connector(service_fn(move |_: Uri| {
UnixStream::connect(unix_socket_path)
}))
.await
.unwrap();
Comment on lines +104 to +110

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How easy would it be to make Endpoint::try_from("unix:///tmp/uds-integration-test) work without the fake url/custom connector? Probably fine to punt.

Copy link
Author

@agreen17 agreen17 Dec 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this would get a bit hairy as implementing something in Endpoint would bubble back up though tonic::transport::Channel and tonic::transport::service::Connection and we'd have to find some way to handle an optional URI inside an Endpoint through those (and update tests, etc). Seems like more trouble than it's worth considering this only saves consumers 1 LOC (although I do realize it is awkward)


let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();
jh.await.unwrap();

// tokio's `UnixListener` does not cleanup the socket automatically - we need to manually
// remove the file at the end of the test.
tokio::fs::remove_file(unix_socket_path).await.unwrap();
}
}
4 changes: 2 additions & 2 deletions tonic/src/request.rs
Expand Up @@ -202,8 +202,8 @@ impl<T> Request<T> {
/// Get the remote address of this connection.
///
/// This will return `None` if the `IO` type used
/// does not implement `Connected`. This currently,
/// only works on the server side.
/// does not implement `Connected` or when using a unix domain socket.
/// This currently only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
#[cfg(feature = "transport")]
{
Expand Down
32 changes: 31 additions & 1 deletion tonic/src/transport/server/conn.rs
Expand Up @@ -4,7 +4,7 @@ use tokio::net::TcpStream;

#[cfg(feature = "tls")]
use crate::transport::Certificate;
#[cfg(feature = "tls")]
#[cfg(any(unix, feature = "tls"))]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};
Expand Down Expand Up @@ -98,6 +98,36 @@ impl Connected for TcpStream {
}
}

/// Connection info for Unix domain socket streams.
///
/// This type will be accessible through [request extensions][ext] if you're using
/// a unix stream.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
/// Peer address. This will be "unnamed" for client unix sockets.
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
/// Process credentials for the unix socket.
pub peer_cred: Option<tokio::net::unix::UCred>,
}

#[cfg(unix)]
impl Connected for tokio::net::UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.peer_addr().ok().map(Arc::new),
peer_cred: self.peer_cred().ok(),
}
}
}

impl Connected for tokio::io::DuplexStream {
type ConnectInfo = ();

Expand Down
3 changes: 3 additions & 0 deletions tonic/src/transport/server/mod.rs
Expand Up @@ -17,6 +17,9 @@ pub use conn::TlsConnectInfo;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;

#[cfg(unix)]
pub use conn::UdsConnectInfo;

use incoming::TcpIncoming;

#[cfg(feature = "tls")]
Expand Down