Skip to content

Commit

Permalink
Merge pull request #1 from StarryInternet/agreen/add-unix-socket-supp…
Browse files Browse the repository at this point in the history
…ort-in-server

More ergonomic support for using a unix socket with a tonic server
  • Loading branch information
agreen17 committed Dec 8, 2021
2 parents c62f382 + 22d1f1e commit 8f3dabf
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 82 deletions.
93 changes: 15 additions & 78 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
#![cfg_attr(not(unix), allow(unused_imports))]

use futures::TryFutureExt;
use std::path::Path;
#[cfg(unix)]
use tokio::net::UnixListener;
use tonic::{transport::Server, Request, Response, Status};
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Server},
Request, Response, Status,
};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand All @@ -26,8 +30,13 @@ impl Greeter for MyGreeter {
) -> Result<Response<HelloReply>, Status> {
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
let conn_info = request.extensions().get::<UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);

// Client-side unix sockets are unnamed.
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());
}

let reply = hello_world::HelloReply {
Expand All @@ -46,89 +55,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let greeter = MyGreeter::default();

let incoming = {
let uds = UnixListener::bind(path)?;

async_stream::stream! {
loop {
let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await;

yield item;
}
}
};
let uds = UnixListener::bind(path)?;
let uds_stream = UnixListenerStream::new(uds);

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(incoming)
.serve_with_incoming(uds_stream)
.await?;

Ok(())
}

#[cfg(unix)]
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::server::Connected;

#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for UnixStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}

#[cfg(not(unix))]
fn main() {
panic!("The `uds` example only works on unix systems!");
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ version = "0.1.0"
bytes = "1.0"
futures-util = "0.3"
prost = "0.9"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tokio = {version = "1.0", features = ["fs", "macros", "rt-multi-thread", "net"]}
tonic = {path = "../../tonic"}

[dev-dependencies]
Expand Down
73 changes: 73 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,76 @@ async fn getting_connect_info() {

jh.await.unwrap();
}

#[cfg(unix)]
pub mod unix {
use std::path::Path;

use futures_util::FutureExt;
use tokio::{
net::{UnixListener, UnixStream},
sync::oneshot,
};
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Endpoint, Server, Uri},
Request, Response, Status,
};
use tower::service_fn;

use integration_tests::pb::{test_client, test_server, Input, Output};

struct Svc {}

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let conn_info = req.extensions().get::<UdsConnectInfo>().unwrap();

// Client-side unix sockets are unnamed.
assert!(req.remote_addr().is_none());
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());

Ok(Response::new(Output {}))
}
}

#[tokio::test]
async fn getting_connect_info() {
let unix_socket_path = "/tmp/uds-integration-test";
let uds = UnixListener::bind(unix_socket_path).unwrap();
let uds_stream = UnixListenerStream::new(uds);

let service = test_server::TestServer::new(Svc {});
let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(service)
.serve_with_incoming_shutdown(uds_stream, rx.map(drop))
.await
.unwrap();
});

let channel = Endpoint::try_from("http://[::]:50051")
.unwrap()
.connect_with_connector(service_fn(move |_: Uri| {
UnixStream::connect(unix_socket_path)
}))
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();
jh.await.unwrap();

// tokio's `UnixListener` does not cleanup the socket automatically - we need to manually
// remove the file at the end of the test.
tokio::fs::remove_file(unix_socket_path).await.unwrap();
}
}
4 changes: 2 additions & 2 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ impl<T> Request<T> {
/// Get the remote address of this connection.
///
/// This will return `None` if the `IO` type used
/// does not implement `Connected`. This currently,
/// only works on the server side.
/// does not implement `Connected` or when using a unix domain socket.
/// This currently only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
#[cfg(feature = "transport")]
{
Expand Down
32 changes: 31 additions & 1 deletion tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tokio::net::TcpStream;

#[cfg(feature = "tls")]
use crate::transport::Certificate;
#[cfg(feature = "tls")]
#[cfg(any(unix, feature = "tls"))]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};
Expand Down Expand Up @@ -98,6 +98,36 @@ impl Connected for TcpStream {
}
}

/// Connection info for Unix domain socket streams.
///
/// This type will be accessible through [request extensions][ext] if you're using
/// a unix stream.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(unix)]
#[cfg_attr(docsrs, doc(cfg(unix)))]
#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
/// Peer address. This will be "unnamed" for client unix sockets.
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
/// Process credentials for the unix socket.
pub peer_cred: Option<tokio::net::unix::UCred>,
}

#[cfg(unix)]
impl Connected for tokio::net::UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.peer_addr().ok().map(Arc::new),
peer_cred: self.peer_cred().ok(),
}
}
}

impl Connected for tokio::io::DuplexStream {
type ConnectInfo = ();

Expand Down
3 changes: 3 additions & 0 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ pub use conn::TlsConnectInfo;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;

#[cfg(unix)]
pub use conn::UdsConnectInfo;

use incoming::TcpIncoming;

#[cfg(feature = "tls")]
Expand Down

0 comments on commit 8f3dabf

Please sign in to comment.