Skip to content

Commit

Permalink
Change Server::tls(path, path) to return a builder instead
Browse files Browse the repository at this point in the history
closes #223
  • Loading branch information
jxs authored and seanmonstar committed Dec 18, 2019
1 parent b6cbcca commit 1972f01
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 45 deletions.
4 changes: 3 additions & 1 deletion examples/tls.rs
Expand Up @@ -11,7 +11,9 @@ async fn main() {
let routes = warp::any().map(|| "Hello, World!");

warp::serve(routes)
.tls("examples/tls/cert.pem", "examples/tls/key.rsa")
.tls()
.cert_path("examples/tls/cert.pem")
.key_path("examples/tls/key.rs")
.run(([127, 0, 0, 1], 3030)).await;
}

Expand Down
56 changes: 47 additions & 9 deletions src/server.rs
@@ -1,8 +1,10 @@
use std::error::Error as StdError;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::Path;
use crate::tls::TlsConfigBuilder;
use std::sync::Arc;
#[cfg(feature = "tls")]
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::future::Future;
Expand Down Expand Up @@ -44,7 +46,8 @@ pub struct Server<S> {
#[cfg(feature = "tls")]
pub struct TlsServer<S> {
server: Server<S>,
tls: ::rustls::ServerConfig,
tls: TlsConfigBuilder,
//tls: ::rustls::ServerConfig,
}

// Getting all various generic bounds to make this a re-usable method is
Expand Down Expand Up @@ -84,7 +87,8 @@ macro_rules! bind_inner {
(tls: $this:ident, $addr:expr) => {{
let service = into_service!($this.server.service);
let (addr, incoming) = addr_incoming!($addr);
let srv = HyperServer::builder(crate::tls::TlsAcceptor::new($this.tls, incoming))
let tls = $this.tls.build().expect("TODO");
let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming))
.http1_pipeline_flush($this.server.pipeline)
.serve(service);
Ok::<_, hyper::error::Error>((addr, srv))
Expand Down Expand Up @@ -334,14 +338,15 @@ where
self
}

/// Configure a server to use TLS with the supplied certificate and key files.
/// Configure a server to use TLS.
///
/// *This function requires the `"tls"` feature.*
#[cfg(feature = "tls")]
pub fn tls(self, cert: impl AsRef<Path>, key: impl AsRef<Path>) -> TlsServer<S> {
let tls = crate::tls::configure(cert.as_ref(), key.as_ref());

TlsServer { server: self, tls }
pub fn tls(self) -> TlsServer<S> {
TlsServer {
server: self,
tls: TlsConfigBuilder::new(),
}
}
}

Expand All @@ -354,6 +359,39 @@ where
<<S::Service as WarpService>::Reply as TryFuture>::Ok: Reply + Send,
<<S::Service as WarpService>::Reply as TryFuture>::Error: IsReject + Send,
{
// TLS config methods

/// Specify the file path to read the private key.
pub fn key_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.key_path(path))
}

/// Specify the file path to read the certificate.
pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.cert_path(path))
}

/// Specify the in-memory contents of the private key.
pub fn key(self, key: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.key(key.as_ref()))
}

/// Specify the in-memory contents of the certificate.
pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.cert(cert.as_ref()))
}

fn with_tls<F>(self, func: F) -> Self
where
F: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
{
let TlsServer { server, tls } = self;
let tls = func(tls);
TlsServer { server, tls }
}

// Server run methods

/// Run this `TlsServer` forever on the current thread.
///
/// *This function requires the `"tls"` feature.*
Expand All @@ -366,7 +404,7 @@ where
}

/// Bind to a socket address, returning a `Future` that can be
/// executed on any runtime.
/// executed on a runtime.
///
/// *This function requires the `"tls"` feature.*
pub async fn bind(
Expand Down
188 changes: 153 additions & 35 deletions src/tls.rs
@@ -1,57 +1,149 @@
use std::fs::File;
use std::io::{self, BufReader, Read, Write};
use std::io::{self, BufReader, Cursor, Read, Write};
use std::net::SocketAddr;
use std::path::Path;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::pin::Pin;
use std::ptr::null_mut;
use std::task::{Poll, Context};

use futures::ready;
use rustls::{self, ServerConfig, ServerSession, Session, Stream};
use rustls::{self, ServerConfig, ServerSession, Session, Stream, TLSError};
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::transport::Transport;

pub(crate) fn configure(cert: &Path, key: &Path) -> ServerConfig {
let cert = {
let file = File::open(cert).unwrap_or_else(|e| panic!("tls cert file error: {}", e));
let mut rdr = BufReader::new(file);
rustls::internal::pemfile::certs(&mut rdr)
.unwrap_or_else(|()| panic!("tls cert parse error"))
};

let key = {
let mut pkcs8 = {
let file = File::open(&key).unwrap_or_else(|e| panic!("tls key file error: {}", e));
let mut rdr = BufReader::new(file);
rustls::internal::pemfile::pkcs8_private_keys(&mut rdr)
.unwrap_or_else(|()| panic!("tls key pkcs8 error"))
};
/// Represents errors that can occur building the TlsConfig
#[derive(Debug)]
pub(crate) enum TlsConfigError {
/// An Error parsing the Certificate
CertParseError,
/// An Error parsing a Pkcs8 key
Pkcs8ParseError,
/// An Error parsing a Rsa key
RsaParseError,
/// An error from an empty key
EmptyKey,
/// An error from an invalid key
InvalidKey(TLSError)
}

impl std::fmt::Display for TlsConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TlsConfigError::CertParseError => write!(f, "certificate parse error"),
TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"),
TlsConfigError::RsaParseError => write!(f, "rsa parse error"),
TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {}", err),
}
}
}

/// Builder to set the configuration for the Tls server.
pub(crate) struct TlsConfigBuilder {
cert: Box<dyn Read + Send + Sync>,
key: Box<dyn Read + Send + Sync>,
}

impl std::fmt::Debug for TlsConfigBuilder {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
f.debug_struct("TlsConfigBuilder")
.finish()
}
}

impl TlsConfigBuilder {
/// Create a new TlsConfigBuilder
pub(crate) fn new() -> TlsConfigBuilder {
TlsConfigBuilder {
key: Box::new(io::empty()),
cert: Box::new(io::empty()),
}
}

if !pkcs8.is_empty() {
pkcs8.remove(0)
} else {
let file = File::open(key).unwrap_or_else(|e| panic!("tls key file error: {}", e));
let mut rdr = BufReader::new(file);
let mut rsa = rustls::internal::pemfile::rsa_private_keys(&mut rdr)
.unwrap_or_else(|()| panic!("tls key rsa error"));
/// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open
pub(crate) fn key_path(mut self, path: impl AsRef<Path>) -> Self {
self.key = Box::new(LazyFile {
path: path.as_ref().into(),
file: None,
});
self
}

if !rsa.is_empty() {
rsa.remove(0)
/// sets the Tls key via bytes slice
pub(crate) fn key(mut self, key: &[u8]) -> Self {
self.key = Box::new(Cursor::new(Vec::from(key)));
self
}


/// Specify the file path for the TLS certificate to use.
pub(crate) fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
self.cert = Box::new(LazyFile {
path: path.as_ref().into(),
file: None,
});
self
}

/// sets the Tls certificate via bytes slice
pub(crate) fn cert(mut self, cert: &[u8]) -> Self {
self.cert = Box::new(Cursor::new(Vec::from(cert)));
self
}

pub(crate) fn build(self) -> Result<ServerConfig, TlsConfigError> {
let mut cert_rdr = BufReader::new(self.cert);
let cert = rustls::internal::pemfile::certs(&mut cert_rdr)
.map_err(|()| TlsConfigError::CertParseError)?;

let key = {
// convert it to Vec<u8> to allow reading it again if key is RSA

let mut key_rdr = BufReader::new(self.key);
let mut pkcs8_buf = BufReader::new(key_rdr.buffer());

let mut pkcs8 = rustls::internal::pemfile::pkcs8_private_keys(&mut pkcs8_buf)
.map_err(|()| TlsConfigError::Pkcs8ParseError)?;

if !pkcs8.is_empty() {
pkcs8.remove(0)
} else {
panic!("tls key path contains no private key");
let mut rsa = rustls::internal::pemfile::rsa_private_keys(&mut key_rdr)
.map_err(|()| TlsConfigError::RsaParseError)?;

if !rsa.is_empty() {
rsa.remove(0)
} else {
return Err(TlsConfigError::EmptyKey);
}
}
};

let mut config = ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(cert, key)
.map_err(|err| TlsConfigError::InvalidKey(err))?;
config.set_protocols(&["h2".into(), "http/1.1".into()]);
Ok(config)
}
}

struct LazyFile {
path: PathBuf,
file: Option<File>,
}

impl Read for LazyFile {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.file.is_none() {
self.file = Some(File::open(&self.path)?);
}
};

let mut tls = ServerConfig::new(rustls::NoClientAuth::new());
tls.set_single_cert(cert, key)
.unwrap_or_else(|e| panic!("tls failed: {}", e));
tls.set_protocols(&["h2".into(), "http/1.1".into()]);
tls
self.file.as_mut().unwrap().read(buf)
}
}

/// a wrapper arround T to allow for rustls Stream read/write translations to async read and write
Expand Down Expand Up @@ -244,4 +336,30 @@ impl Accept for TlsAcceptor {
None => Poll::Ready(None)
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn file_cert_key() {
TlsConfigBuilder::new()
.key_path("examples/tls/key.rsa")
.cert_path("examples/tls/cert.pem")
.build()
.unwrap();
}

#[test]
fn bytes_cert_key() {
let key = include_str!("../examples/tls/key.rsa");
let cert = include_str!("../examples/tls/cert.pem");

TlsConfigBuilder::new()
.key(key.as_bytes())
.cert(cert.as_bytes())
.build()
.unwrap();
}
}

0 comments on commit 1972f01

Please sign in to comment.