diff --git a/examples/tls.rs b/examples/tls.rs index ffe784e8e..3f117013b 100644 --- a/examples/tls.rs +++ b/examples/tls.rs @@ -11,7 +11,9 @@ async fn main() { let routes = warp::any().map(|| "Hello, World!"); warp::serve(routes) - .tls("examples/tls/cert.pem", "examples/tls/key.rsa") + .tls() + .cert_path("examples/tls/cert.pem") + .key_path("examples/tls/key.rsa") .run(([127, 0, 0, 1], 3030)).await; } diff --git a/src/error.rs b/src/error.rs index aec7beffc..6139f4b67 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,14 +1,17 @@ use std::error::Error as StdError; use std::convert::Infallible; use std::fmt; -use std::io; -use hyper::Error as HyperError; -#[cfg(feature = "websocket")] -use tungstenite::Error as WsError; +type BoxError = Box; /// Errors that can happen inside warp. -pub struct Error(Box); +pub struct Error(BoxError); + +impl Error { + pub(crate) fn new>(err: E) -> Error { + Error(err.into()) + } +} impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -19,51 +22,11 @@ impl fmt::Debug for Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.0.as_ref() { - Kind::Hyper(ref e) => fmt::Display::fmt(e, f), - Kind::Multipart(ref e) => fmt::Display::fmt(e, f), - #[cfg(feature = "websocket")] - Kind::Ws(ref e) => fmt::Display::fmt(e, f), - } - } -} - -impl StdError for Error { - #[allow(deprecated)] - fn cause(&self) -> Option<&dyn StdError> { - match self.0.as_ref() { - Kind::Hyper(ref e) => e.cause(), - Kind::Multipart(ref e) => e.cause(), - #[cfg(feature = "websocket")] - Kind::Ws(ref e) => e.cause(), - } - } -} - -pub(crate) enum Kind { - Hyper(HyperError), - Multipart(io::Error), - #[cfg(feature = "websocket")] - Ws(WsError), -} - -impl fmt::Debug for Kind { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Kind::Hyper(ref e) => fmt::Debug::fmt(e, f), - Kind::Multipart(ref e) => fmt::Debug::fmt(e, f), - #[cfg(feature = "websocket")] - Kind::Ws(ref e) => fmt::Debug::fmt(e, f), - } + fmt::Display::fmt(&self.0, f) } } -#[doc(hidden)] -impl From for Error { - fn from(kind: Kind) -> Error { - Error(Box::new(kind)) - } -} +impl StdError for Error {} impl From for Error { fn from(infallible: Infallible) -> Error { @@ -75,6 +38,6 @@ impl From for Error { fn error_size_of() { assert_eq!( ::std::mem::size_of::(), - ::std::mem::size_of::() + ::std::mem::size_of::() * 2 ); } diff --git a/src/filters/body.rs b/src/filters/body.rs index 924f7bafb..b3f2c8128 100644 --- a/src/filters/body.rs +++ b/src/filters/body.rs @@ -271,7 +271,7 @@ impl Stream for BodyStream { None => Poll::Ready(None), Some(item) => { let stream_buf = item - .map_err(|e| crate::Error::from(crate::error::Kind::Hyper(e))) + .map_err(crate::Error::new) .map(|chunk| StreamBuf { chunk }); Poll::Ready(Some(stream_buf)) diff --git a/src/filters/multipart.rs b/src/filters/multipart.rs index b09548e7d..a11a63263 100644 --- a/src/filters/multipart.rs +++ b/src/filters/multipart.rs @@ -113,7 +113,7 @@ impl Stream for FormData { field .data .read_to_end(&mut data) - .map_err(crate::error::Kind::Multipart)?; + .map_err(crate::Error::new)?; Poll::Ready(Some(Ok(Part { name: field.headers.name.to_string(), filename: field.headers.filename, @@ -122,7 +122,7 @@ impl Stream for FormData { }))) } Ok(None) => Poll::Ready(None), - Err(e) => Poll::Ready(Some(Err(crate::error::Kind::Multipart(e).into()))), + Err(e) => Poll::Ready(Some(Err(crate::Error::new(e)))), } } } diff --git a/src/filters/ws.rs b/src/filters/ws.rs index 8ca363dde..06e794617 100644 --- a/src/filters/ws.rs +++ b/src/filters/ws.rs @@ -13,7 +13,6 @@ use http; use tungstenite::protocol::{self, WebSocketConfig}; use tokio::io::{AsyncRead, AsyncWrite}; use super::{body, header}; -use crate::error::Kind; use crate::filter::{Filter, One}; use crate::reject::Rejection; use crate::reply::{Reply, Response}; @@ -228,13 +227,13 @@ impl Write for AllowStd } } -fn cvt(r: tungstenite::error::Result, err_message: &str) -> Poll> { +fn cvt(r: tungstenite::error::Result, err_message: &str) -> Poll> { match r { Ok(v) => Poll::Ready(Ok(v)), Err(tungstenite::Error::Io(ref e)) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => { log::debug!("{} {}", err_message, e); - Poll::Ready(Err(Kind::Ws(e).into())) + Poll::Ready(Err(crate::Error::new(e))) } } } @@ -282,7 +281,7 @@ impl Stream for WebSocket { } Err(e) => { log::debug!("websocket poll error: {}", e); - return Poll::Ready(Some(Err(Kind::Ws(e).into()))); + return Poll::Ready(Some(Err(crate::Error::new(e)))); } }; @@ -332,7 +331,7 @@ impl Sink for WebSocket { } Err(e) => { log::debug!("websocket start_send error: {}", e); - Err(Kind::Ws(e).into()) + Err(crate::Error::new(e)) } } } @@ -353,7 +352,7 @@ impl Sink for WebSocket { Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), Err(err) => { log::debug!("websocket close error: {}", err); - Poll::Ready(Err(Kind::Ws(err).into())) + Poll::Ready(Err(crate::Error::new(err))) } } } diff --git a/src/server.rs b/src/server.rs index dae9d7be7..77e19dc1e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,10 @@ use std::error::Error as StdError; use std::net::SocketAddr; #[cfg(feature = "tls")] -use std::path::Path; +use crate::tls::TlsConfigBuilder; use std::sync::Arc; +#[cfg(feature = "tls")] +use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; use std::future::Future; @@ -44,7 +46,8 @@ pub struct Server { #[cfg(feature = "tls")] pub struct TlsServer { server: Server, - tls: ::rustls::ServerConfig, + tls: TlsConfigBuilder, + //tls: ::rustls::ServerConfig, } // Getting all various generic bounds to make this a re-usable method is @@ -78,16 +81,17 @@ macro_rules! bind_inner { let srv = HyperServer::builder(incoming) .http1_pipeline_flush($this.pipeline) .serve(service); - Ok::<_, hyper::error::Error>((addr, srv)) + Ok::<_, hyper::Error>((addr, srv)) }}; (tls: $this:ident, $addr:expr) => {{ let service = into_service!($this.server.service); let (addr, incoming) = addr_incoming!($addr); - let srv = HyperServer::builder(crate::tls::TlsAcceptor::new($this.tls, incoming)) + let tls = $this.tls.build()?; + let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming)) .http1_pipeline_flush($this.server.pipeline) .serve(service); - Ok::<_, hyper::error::Error>((addr, srv)) + Ok::<_, Box>((addr, srv)) }}; } @@ -231,10 +235,10 @@ where pub fn try_bind_ephemeral( self, addr: impl Into + 'static, - ) -> Result<(SocketAddr, impl Future + 'static), hyper::error::Error> + ) -> Result<(SocketAddr, impl Future + 'static), crate::Error> { let addr = addr.into(); - let (addr, srv) = try_bind!(self, &addr)?; + let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?; let srv = srv.map(|result| { if let Err(err) = result { log::error!("server error: {}", err) @@ -334,14 +338,15 @@ where self } - /// Configure a server to use TLS with the supplied certificate and key files. + /// Configure a server to use TLS. /// /// *This function requires the `"tls"` feature.* #[cfg(feature = "tls")] - pub fn tls(self, cert: impl AsRef, key: impl AsRef) -> TlsServer { - let tls = crate::tls::configure(cert.as_ref(), key.as_ref()); - - TlsServer { server: self, tls } + pub fn tls(self) -> TlsServer { + TlsServer { + server: self, + tls: TlsConfigBuilder::new(), + } } } @@ -354,6 +359,39 @@ where <::Reply as TryFuture>::Ok: Reply + Send, <::Reply as TryFuture>::Error: IsReject + Send, { + // TLS config methods + + /// Specify the file path to read the private key. + pub fn key_path(self, path: impl AsRef) -> Self { + self.with_tls(|tls| tls.key_path(path)) + } + + /// Specify the file path to read the certificate. + pub fn cert_path(self, path: impl AsRef) -> Self { + self.with_tls(|tls| tls.cert_path(path)) + } + + /// Specify the in-memory contents of the private key. + pub fn key(self, key: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.key(key.as_ref())) + } + + /// Specify the in-memory contents of the certificate. + pub fn cert(self, cert: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.cert(cert.as_ref())) + } + + fn with_tls(self, func: F) -> Self + where + F: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder, + { + let TlsServer { server, tls } = self; + let tls = func(tls); + TlsServer { server, tls } + } + + // Server run methods + /// Run this `TlsServer` forever on the current thread. /// /// *This function requires the `"tls"` feature.* @@ -366,7 +404,7 @@ where } /// Bind to a socket address, returning a `Future` that can be - /// executed on any runtime. + /// executed on a runtime. /// /// *This function requires the `"tls"` feature.* pub async fn bind( diff --git a/src/tls.rs b/src/tls.rs index 318442e71..a132786bf 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,57 +1,166 @@ use std::fs::File; -use std::io::{self, BufReader, Read, Write}; +use std::io::{self, BufReader, Cursor, Read, Write}; use std::net::SocketAddr; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::pin::Pin; use std::ptr::null_mut; use std::task::{Poll, Context}; use futures::ready; -use rustls::{self, ServerConfig, ServerSession, Session, Stream}; +use rustls::{self, ServerConfig, ServerSession, Session, Stream, TLSError}; use hyper::server::accept::Accept; use hyper::server::conn::{AddrIncoming, AddrStream}; use tokio::io::{AsyncRead, AsyncWrite}; use crate::transport::Transport; -pub(crate) fn configure(cert: &Path, key: &Path) -> ServerConfig { - let cert = { - let file = File::open(cert).unwrap_or_else(|e| panic!("tls cert file error: {}", e)); - let mut rdr = BufReader::new(file); - rustls::internal::pemfile::certs(&mut rdr) - .unwrap_or_else(|()| panic!("tls cert parse error")) - }; - - let key = { - let mut pkcs8 = { - let file = File::open(&key).unwrap_or_else(|e| panic!("tls key file error: {}", e)); - let mut rdr = BufReader::new(file); - rustls::internal::pemfile::pkcs8_private_keys(&mut rdr) - .unwrap_or_else(|()| panic!("tls key pkcs8 error")) - }; +/// Represents errors that can occur building the TlsConfig +#[derive(Debug)] +pub(crate) enum TlsConfigError { + Io(io::Error), + /// An Error parsing the Certificate + CertParseError, + /// An Error parsing a Pkcs8 key + Pkcs8ParseError, + /// An Error parsing a Rsa key + RsaParseError, + /// An error from an empty key + EmptyKey, + /// An error from an invalid key + InvalidKey(TLSError) +} + +impl std::fmt::Display for TlsConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TlsConfigError::Io(err) => err.fmt(f), + TlsConfigError::CertParseError => write!(f, "certificate parse error"), + TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"), + TlsConfigError::RsaParseError => write!(f, "rsa parse error"), + TlsConfigError::EmptyKey => write!(f, "key contains no private key"), + TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {}", err), + } + } +} + +impl std::error::Error for TlsConfigError {} + +/// Builder to set the configuration for the Tls server. +pub(crate) struct TlsConfigBuilder { + cert: Box, + key: Box, +} + +impl std::fmt::Debug for TlsConfigBuilder { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.debug_struct("TlsConfigBuilder") + .finish() + } +} + +impl TlsConfigBuilder { + /// Create a new TlsConfigBuilder + pub(crate) fn new() -> TlsConfigBuilder { + TlsConfigBuilder { + key: Box::new(io::empty()), + cert: Box::new(io::empty()), + } + } + + /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open + pub(crate) fn key_path(mut self, path: impl AsRef) -> Self { + self.key = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self + } + + /// sets the Tls key via bytes slice + pub(crate) fn key(mut self, key: &[u8]) -> Self { + self.key = Box::new(Cursor::new(Vec::from(key))); + self + } + + + /// Specify the file path for the TLS certificate to use. + pub(crate) fn cert_path(mut self, path: impl AsRef) -> Self { + self.cert = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self + } + + /// sets the Tls certificate via bytes slice + pub(crate) fn cert(mut self, cert: &[u8]) -> Self { + self.cert = Box::new(Cursor::new(Vec::from(cert))); + self + } - if !pkcs8.is_empty() { - pkcs8.remove(0) - } else { - let file = File::open(key).unwrap_or_else(|e| panic!("tls key file error: {}", e)); - let mut rdr = BufReader::new(file); - let mut rsa = rustls::internal::pemfile::rsa_private_keys(&mut rdr) - .unwrap_or_else(|()| panic!("tls key rsa error")); + pub(crate) fn build(mut self) -> Result { + let mut cert_rdr = BufReader::new(self.cert); + let cert = rustls::internal::pemfile::certs(&mut cert_rdr) + .map_err(|()| TlsConfigError::CertParseError)?; - if !rsa.is_empty() { - rsa.remove(0) + let key = { + // convert it to Vec to allow reading it again if key is RSA + let mut key_vec = Vec::new(); + self.key.read_to_end(&mut key_vec) + .map_err(TlsConfigError::Io)?; + + if key_vec.is_empty() { + return Err(TlsConfigError::EmptyKey); + } + + let mut pkcs8 = rustls::internal::pemfile::pkcs8_private_keys(&mut key_vec.as_slice()) + .map_err(|()| TlsConfigError::Pkcs8ParseError)?; + + if !pkcs8.is_empty() { + pkcs8.remove(0) } else { - panic!("tls key path contains no private key"); + let mut rsa = rustls::internal::pemfile::rsa_private_keys(&mut key_vec.as_slice()) + .map_err(|()| TlsConfigError::RsaParseError)?; + + if !rsa.is_empty() { + rsa.remove(0) + } else { + return Err(TlsConfigError::EmptyKey); + } } + }; + + let mut config = ServerConfig::new(rustls::NoClientAuth::new()); + config.set_single_cert(cert, key) + .map_err(|err| TlsConfigError::InvalidKey(err))?; + config.set_protocols(&["h2".into(), "http/1.1".into()]); + Ok(config) + } +} + +struct LazyFile { + path: PathBuf, + file: Option, +} + +impl LazyFile { + fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result { + if self.file.is_none() { + self.file = Some(File::open(&self.path)?); } - }; - let mut tls = ServerConfig::new(rustls::NoClientAuth::new()); - tls.set_single_cert(cert, key) - .unwrap_or_else(|e| panic!("tls failed: {}", e)); - tls.set_protocols(&["h2".into(), "http/1.1".into()]); - tls + self.file.as_mut().unwrap().read(buf) + } +} + +impl Read for LazyFile { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.lazy_read(buf).map_err(|err| { + let kind = err.kind(); + io::Error::new(kind, format!("error reading file ({:?}): {}", self.path.display(), err)) + }) + } } /// a wrapper arround T to allow for rustls Stream read/write translations to async read and write @@ -244,4 +353,30 @@ impl Accept for TlsAcceptor { None => Poll::Ready(None) } } -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn file_cert_key() { + TlsConfigBuilder::new() + .key_path("examples/tls/key.rsa") + .cert_path("examples/tls/cert.pem") + .build() + .unwrap(); + } + + #[test] + fn bytes_cert_key() { + let key = include_str!("../examples/tls/key.rsa"); + let cert = include_str!("../examples/tls/cert.pem"); + + TlsConfigBuilder::new() + .key(key.as_bytes()) + .cert(cert.as_bytes()) + .build() + .unwrap(); + } +}