Skip to content

Commit

Permalink
Merge pull request #2 from pbzweihander/update
Browse files Browse the repository at this point in the history
  • Loading branch information
LucioFranco committed Oct 14, 2021
2 parents b680f73 + 6885f2f commit 35adb23
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 37 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ members = [
]

[dependencies]
tokio = "0.2"
tonic = "0.2"
async-stream = "0.2"
tokio-openssl = "0.4"
tokio = "1"
tonic = "0.5"
async-stream = "0.3"
tokio-openssl = "0.6"
openssl = "0.10"
futures = { version = "0.3", default-features = false }
15 changes: 8 additions & 7 deletions example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ name = "server"
path = "src/server.rs"

[dependencies]
tonic = "0.2"
tonic = "0.5"
tonic-openssl = { version = "0.1", path = ".." }
hyper = "0.13"
hyper-openssl = "0.8"
prost = "0.6"
tokio = { version = "0.2", features = ["full"] }
hyper = "0.14"
hyper-openssl = "0.9"
prost = "0.8"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"] }
openssl = "0.10"
tower = "0.3"
tower = "0.4"
pretty_env_logger = "*"

[build-dependencies]
tonic-build = "0.2"
tonic-build = "0.5.0"
4 changes: 2 additions & 2 deletions example/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use hello_world::greeter_client::GreeterClient;
use hello_world::HelloRequest;
use hyper::{client::connect::HttpConnector, Client, Uri};
use hyper_openssl::HttpsConnector;
use openssl::{
ssl::{SslConnector, SslMethod},
x509::X509,
};
use tonic_openssl::ALPN_H2_WIRE;
use hyper_openssl::HttpsConnector;
use hyper::{Client, client::connect::HttpConnector, Uri};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand Down
14 changes: 10 additions & 4 deletions example/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use hello_world::{HelloReply, HelloRequest};
use openssl::ssl::{select_next_proto, AlpnError, SslAcceptor, SslFiletype, SslMethod};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tonic_openssl::ALPN_H2_WIRE;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::transport::server::TcpConnectInfo;
use tonic_openssl::{SslConnectInfo, ALPN_H2_WIRE};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand All @@ -29,8 +31,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let addr = "[::1]:50051".parse::<SocketAddr>()?;

let mut listener = TcpListener::bind(addr).await?;
let incoming = tonic_openssl::incoming(listener.incoming(), acceptor);
let listener = TcpListener::bind(addr).await?;
let incoming = tonic_openssl::incoming(TcpListenerStream::new(listener), acceptor);

let greeter = MyGreeter::default();

Expand All @@ -53,7 +55,11 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request from {:?}", request.remote_addr());
let remote_addr = request
.extensions()
.get::<SslConnectInfo<TcpConnectInfo>>()
.and_then(|info| info.get_ref().remote_addr());
println!("Got a request from {:?}", remote_addr);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
Expand Down
72 changes: 52 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

use async_stream::try_stream;
use futures::{Stream, TryStream, TryStreamExt};
use openssl::ssl::SslAcceptor;
use openssl::ssl::{Ssl, SslAcceptor};
use std::{
fmt::Debug,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::{server::Connected, Certificate};

/// Wrapper error type.
Expand All @@ -46,7 +46,9 @@ where

try_stream! {
while let Some(stream) = incoming.try_next().await? {
let tls = tokio_openssl::accept(&acceptor, stream).await?;
let ssl = Ssl::new(acceptor.context())?;
let mut tls = tokio_openssl::SslStream::new(ssl, stream)?;
Pin::new(&mut tls).accept().await?;

let ssl = SslStream {
inner: tls
Expand All @@ -65,22 +67,24 @@ pub struct SslStream<S> {
}

impl<S: Connected> Connected for SslStream<S> {
fn remote_addr(&self) -> Option<SocketAddr> {
let tcp = self.inner.get_ref();
tcp.remote_addr()
}

fn peer_certs(&self) -> Option<Vec<Certificate>> {
let ssl = self.inner.ssl();
let certs = ssl.verified_chain()?;
type ConnectInfo = SslConnectInfo<S::ConnectInfo>;

let certs = certs
.iter()
.filter_map(|c| c.to_pem().ok())
.map(Certificate::from_pem)
.collect();
fn connect_info(&self) -> Self::ConnectInfo {
let inner = self.inner.get_ref().connect_info();

Some(certs)
let ssl = self.inner.ssl();
let certs = ssl
.verified_chain()
.map(|certs| {
certs
.iter()
.filter_map(|c| c.to_pem().ok())
.map(Certificate::from_pem)
.collect()
})
.map(Arc::new);

SslConnectInfo { inner, certs }
}
}

Expand All @@ -91,8 +95,8 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
Expand All @@ -117,3 +121,31 @@ where
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}

/// Connection info for SSL streams.
///
/// This type will be accessible through [request extensions](tonic::Request::extensions).
///
/// See [`Connected`](tonic::transport::server::Connected) for more details.
#[derive(Debug, Clone)]
pub struct SslConnectInfo<T> {
inner: T,
certs: Option<Arc<Vec<Certificate>>>,
}

impl<T> SslConnectInfo<T> {
/// Get a reference to the underlying connection info.
pub fn get_ref(&self) -> &T {
&self.inner
}

/// Get a mutable reference to the underlying connection info.
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}

/// Return the set of connected peer SSL certificates.
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.certs.clone()
}
}

0 comments on commit 35adb23

Please sign in to comment.