Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to latest tokio, tonic, and tokio-openssl #2

Merged
merged 1 commit into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ members = [
]

[dependencies]
tokio = "0.2"
tonic = "0.2"
async-stream = "0.2"
tokio-openssl = "0.4"
tokio = "1"
tonic = "0.5"
async-stream = "0.3"
tokio-openssl = "0.6"
openssl = "0.10"
futures = { version = "0.3", default-features = false }
15 changes: 8 additions & 7 deletions example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ name = "server"
path = "src/server.rs"

[dependencies]
tonic = "0.2"
tonic = "0.5"
tonic-openssl = { version = "0.1", path = ".." }
hyper = "0.13"
hyper-openssl = "0.8"
prost = "0.6"
tokio = { version = "0.2", features = ["full"] }
hyper = "0.14"
hyper-openssl = "0.9"
prost = "0.8"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"] }
openssl = "0.10"
tower = "0.3"
tower = "0.4"
pretty_env_logger = "*"

[build-dependencies]
tonic-build = "0.2"
tonic-build = "0.5.0"
4 changes: 2 additions & 2 deletions example/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use hello_world::greeter_client::GreeterClient;
use hello_world::HelloRequest;
use hyper::{client::connect::HttpConnector, Client, Uri};
use hyper_openssl::HttpsConnector;
use openssl::{
ssl::{SslConnector, SslMethod},
x509::X509,
};
use tonic_openssl::ALPN_H2_WIRE;
use hyper_openssl::HttpsConnector;
use hyper::{Client, client::connect::HttpConnector, Uri};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand Down
14 changes: 10 additions & 4 deletions example/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use hello_world::{HelloReply, HelloRequest};
use openssl::ssl::{select_next_proto, AlpnError, SslAcceptor, SslFiletype, SslMethod};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tonic_openssl::ALPN_H2_WIRE;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::transport::server::TcpConnectInfo;
use tonic_openssl::{SslConnectInfo, ALPN_H2_WIRE};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand All @@ -29,8 +31,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let addr = "[::1]:50051".parse::<SocketAddr>()?;

let mut listener = TcpListener::bind(addr).await?;
let incoming = tonic_openssl::incoming(listener.incoming(), acceptor);
let listener = TcpListener::bind(addr).await?;
let incoming = tonic_openssl::incoming(TcpListenerStream::new(listener), acceptor);

let greeter = MyGreeter::default();

Expand All @@ -53,7 +55,11 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request from {:?}", request.remote_addr());
let remote_addr = request
.extensions()
.get::<SslConnectInfo<TcpConnectInfo>>()
.and_then(|info| info.get_ref().remote_addr());
println!("Got a request from {:?}", remote_addr);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
Expand Down
72 changes: 52 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

use async_stream::try_stream;
use futures::{Stream, TryStream, TryStreamExt};
use openssl::ssl::SslAcceptor;
use openssl::ssl::{Ssl, SslAcceptor};
use std::{
fmt::Debug,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::{server::Connected, Certificate};

/// Wrapper error type.
Expand All @@ -46,7 +46,9 @@ where

try_stream! {
while let Some(stream) = incoming.try_next().await? {
let tls = tokio_openssl::accept(&acceptor, stream).await?;
let ssl = Ssl::new(acceptor.context())?;
let mut tls = tokio_openssl::SslStream::new(ssl, stream)?;
Pin::new(&mut tls).accept().await?;

let ssl = SslStream {
inner: tls
Expand All @@ -65,22 +67,24 @@ pub struct SslStream<S> {
}

impl<S: Connected> Connected for SslStream<S> {
fn remote_addr(&self) -> Option<SocketAddr> {
let tcp = self.inner.get_ref();
tcp.remote_addr()
}

fn peer_certs(&self) -> Option<Vec<Certificate>> {
let ssl = self.inner.ssl();
let certs = ssl.verified_chain()?;
type ConnectInfo = SslConnectInfo<S::ConnectInfo>;

let certs = certs
.iter()
.filter_map(|c| c.to_pem().ok())
.map(Certificate::from_pem)
.collect();
fn connect_info(&self) -> Self::ConnectInfo {
let inner = self.inner.get_ref().connect_info();

Some(certs)
let ssl = self.inner.ssl();
let certs = ssl
.verified_chain()
.map(|certs| {
certs
.iter()
.filter_map(|c| c.to_pem().ok())
.map(Certificate::from_pem)
.collect()
})
.map(Arc::new);

SslConnectInfo { inner, certs }
}
}

Expand All @@ -91,8 +95,8 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
Expand All @@ -117,3 +121,31 @@ where
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}

/// Connection info for SSL streams.
///
/// This type will be accessible through [request extensions](tonic::Request::extensions).
///
/// See [`Connected`](tonic::transport::server::Connected) for more details.
#[derive(Debug, Clone)]
pub struct SslConnectInfo<T> {
inner: T,
certs: Option<Arc<Vec<Certificate>>>,
}

impl<T> SslConnectInfo<T> {
/// Get a reference to the underlying connection info.
pub fn get_ref(&self) -> &T {
&self.inner
}

/// Get a mutable reference to the underlying connection info.
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}

/// Return the set of connected peer SSL certificates.
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.certs.clone()
}
}