Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): add unix socket support in server #861

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
85 changes: 8 additions & 77 deletions examples/src/uds/server.rs
@@ -1,9 +1,12 @@
#![cfg_attr(not(unix), allow(unused_imports))]

use futures::TryFutureExt;
use std::path::Path;
#[cfg(unix)]
use tokio::net::UnixListener;
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
#[cfg(unix)]
use tonic::transport::server::UdsConnectInfo;
use tonic::{transport::Server, Request, Response, Status};

pub mod hello_world {
Expand All @@ -26,7 +29,7 @@ impl Greeter for MyGreeter {
) -> Result<Response<HelloReply>, Status> {
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
let conn_info = request.extensions().get::<UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);
}

Expand All @@ -46,89 +49,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let greeter = MyGreeter::default();

let incoming = {
let uds = UnixListener::bind(path)?;

async_stream::stream! {
loop {
let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await;

yield item;
}
}
};
let uds = UnixListener::bind(path)?;
let uds_stream = UnixListenerStream::new(uds);

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(incoming)
.serve_with_incoming(uds_stream)
.await?;

Ok(())
}

#[cfg(unix)]
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::server::Connected;

#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for UnixStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}

#[cfg(not(unix))]
fn main() {
panic!("The `uds` example only works on unix systems!");
Expand Down
74 changes: 74 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Expand Up @@ -48,3 +48,77 @@ async fn getting_connect_info() {

jh.await.unwrap();
}

#[cfg(unix)]
pub mod unix {
use std::convert::TryFrom as _;

use futures_util::FutureExt;
use tokio::{
net::{UnixListener, UnixStream},
sync::oneshot,
};
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Endpoint, Server, Uri},
Request, Response, Status,
};
use tower::service_fn;

use integration_tests::pb::{test_client, test_server, Input, Output};

struct Svc {}

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let conn_info = req.extensions().get::<UdsConnectInfo>().unwrap();

// Client-side unix sockets are unnamed.
assert!(req.remote_addr().is_none());
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());

Ok(Response::new(Output {}))
}
}

#[tokio::test]
async fn getting_connect_info() {
let mut unix_socket_path = std::env::temp_dir();
unix_socket_path.push("uds-integration-test");

let uds = UnixListener::bind(&unix_socket_path).unwrap();
let uds_stream = UnixListenerStream::new(uds);

let service = test_server::TestServer::new(Svc {});
let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(service)
.serve_with_incoming_shutdown(uds_stream, rx.map(drop))
.await
.unwrap();
});

// Take a copy before moving into the `service_fn` closure so that the closure
// can implement `FnMut`.
let path = unix_socket_path.clone();
let channel = Endpoint::try_from("http://[::]:50051")
.unwrap()
.connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone())))
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();
jh.await.unwrap();

std::fs::remove_file(unix_socket_path).unwrap();
}
}
4 changes: 2 additions & 2 deletions tonic/src/request.rs
Expand Up @@ -202,8 +202,8 @@ impl<T> Request<T> {
/// Get the remote address of this connection.
///
/// This will return `None` if the `IO` type used
/// does not implement `Connected`. This currently,
/// only works on the server side.
/// does not implement `Connected` or when using a unix domain socket.
/// This currently only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
#[cfg(feature = "transport")]
{
Expand Down
5 changes: 5 additions & 0 deletions tonic/src/transport/server/mod.rs
Expand Up @@ -6,6 +6,8 @@ mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
#[cfg(unix)]
mod unix;

pub use conn::{Connected, TcpConnectInfo};
#[cfg(feature = "tls")]
Expand All @@ -17,6 +19,9 @@ pub use conn::TlsConnectInfo;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;

#[cfg(unix)]
pub use unix::UdsConnectInfo;

use incoming::TcpIncoming;

#[cfg(feature = "tls")]
Expand Down
31 changes: 31 additions & 0 deletions tonic/src/transport/server/unix.rs
@@ -0,0 +1,31 @@
use super::Connected;
use std::sync::Arc;

/// Connection info for Unix domain socket streams.
///
/// This type will be accessible through [request extensions][ext] if you're using
/// a unix stream.
///
/// See [Connected] for more details.
///
/// [ext]: crate::Request::extensions
/// [Connected]: crate::transport::server::Connected
#[cfg_attr(docsrs, doc(cfg(unix)))]
#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
/// Peer address. This will be "unnamed" for client unix sockets.
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
/// Process credentials for the unix socket.
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl Connected for tokio::net::UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.peer_addr().ok().map(Arc::new),
peer_cred: self.peer_cred().ok(),
}
}
}