Skip to content

Commit

Permalink
Change Server::tls(path, path) to return a builder instead (#340)
Browse files Browse the repository at this point in the history
closes #223
  • Loading branch information
seanmonstar committed Dec 18, 2019
1 parent ea42ba5 commit 3a16d83
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 106 deletions.
4 changes: 3 additions & 1 deletion examples/tls.rs
Expand Up @@ -11,7 +11,9 @@ async fn main() {
let routes = warp::any().map(|| "Hello, World!");

warp::serve(routes)
.tls("examples/tls/cert.pem", "examples/tls/key.rsa")
.tls()
.cert_path("examples/tls/cert.pem")
.key_path("examples/tls/key.rsa")
.run(([127, 0, 0, 1], 3030)).await;
}

Expand Down
59 changes: 11 additions & 48 deletions src/error.rs
@@ -1,14 +1,17 @@
use std::error::Error as StdError;
use std::convert::Infallible;
use std::fmt;
use std::io;

use hyper::Error as HyperError;
#[cfg(feature = "websocket")]
use tungstenite::Error as WsError;
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Errors that can happen inside warp.
pub struct Error(Box<Kind>);
pub struct Error(BoxError);

impl Error {
pub(crate) fn new<E: Into<BoxError>>(err: E) -> Error {
Error(err.into())
}
}

impl fmt::Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand All @@ -19,51 +22,11 @@ impl fmt::Debug for Error {

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.0.as_ref() {
Kind::Hyper(ref e) => fmt::Display::fmt(e, f),
Kind::Multipart(ref e) => fmt::Display::fmt(e, f),
#[cfg(feature = "websocket")]
Kind::Ws(ref e) => fmt::Display::fmt(e, f),
}
}
}

impl StdError for Error {
#[allow(deprecated)]
fn cause(&self) -> Option<&dyn StdError> {
match self.0.as_ref() {
Kind::Hyper(ref e) => e.cause(),
Kind::Multipart(ref e) => e.cause(),
#[cfg(feature = "websocket")]
Kind::Ws(ref e) => e.cause(),
}
}
}

pub(crate) enum Kind {
Hyper(HyperError),
Multipart(io::Error),
#[cfg(feature = "websocket")]
Ws(WsError),
}

impl fmt::Debug for Kind {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Kind::Hyper(ref e) => fmt::Debug::fmt(e, f),
Kind::Multipart(ref e) => fmt::Debug::fmt(e, f),
#[cfg(feature = "websocket")]
Kind::Ws(ref e) => fmt::Debug::fmt(e, f),
}
fmt::Display::fmt(&self.0, f)
}
}

#[doc(hidden)]
impl From<Kind> for Error {
fn from(kind: Kind) -> Error {
Error(Box::new(kind))
}
}
impl StdError for Error {}

impl From<Infallible> for Error {
fn from(infallible: Infallible) -> Error {
Expand All @@ -75,6 +38,6 @@ impl From<Infallible> for Error {
fn error_size_of() {
assert_eq!(
::std::mem::size_of::<Error>(),
::std::mem::size_of::<usize>()
::std::mem::size_of::<usize>() * 2
);
}
2 changes: 1 addition & 1 deletion src/filters/body.rs
Expand Up @@ -271,7 +271,7 @@ impl Stream for BodyStream {
None => Poll::Ready(None),
Some(item) => {
let stream_buf = item
.map_err(|e| crate::Error::from(crate::error::Kind::Hyper(e)))
.map_err(crate::Error::new)
.map(|chunk| StreamBuf { chunk });

Poll::Ready(Some(stream_buf))
Expand Down
4 changes: 2 additions & 2 deletions src/filters/multipart.rs
Expand Up @@ -113,7 +113,7 @@ impl Stream for FormData {
field
.data
.read_to_end(&mut data)
.map_err(crate::error::Kind::Multipart)?;
.map_err(crate::Error::new)?;
Poll::Ready(Some(Ok(Part {
name: field.headers.name.to_string(),
filename: field.headers.filename,
Expand All @@ -122,7 +122,7 @@ impl Stream for FormData {
})))
}
Ok(None) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(crate::error::Kind::Multipart(e).into()))),
Err(e) => Poll::Ready(Some(Err(crate::Error::new(e)))),
}
}
}
Expand Down
11 changes: 5 additions & 6 deletions src/filters/ws.rs
Expand Up @@ -13,7 +13,6 @@ use http;
use tungstenite::protocol::{self, WebSocketConfig};
use tokio::io::{AsyncRead, AsyncWrite};
use super::{body, header};
use crate::error::Kind;
use crate::filter::{Filter, One};
use crate::reject::Rejection;
use crate::reply::{Reply, Response};
Expand Down Expand Up @@ -228,13 +227,13 @@ impl Write for AllowStd
}
}

fn cvt<T>(r: tungstenite::error::Result<T>, err_message: &str) -> Poll<Result<T, crate::error::Error>> {
fn cvt<T>(r: tungstenite::error::Result<T>, err_message: &str) -> Poll<Result<T, crate::Error>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(tungstenite::Error::Io(ref e)) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => {
log::debug!("{} {}", err_message, e);
Poll::Ready(Err(Kind::Ws(e).into()))
Poll::Ready(Err(crate::Error::new(e)))
}
}
}
Expand Down Expand Up @@ -282,7 +281,7 @@ impl Stream for WebSocket {
}
Err(e) => {
log::debug!("websocket poll error: {}", e);
return Poll::Ready(Some(Err(Kind::Ws(e).into())));
return Poll::Ready(Some(Err(crate::Error::new(e))));
}
};

Expand Down Expand Up @@ -332,7 +331,7 @@ impl Sink<Message> for WebSocket {
}
Err(e) => {
log::debug!("websocket start_send error: {}", e);
Err(Kind::Ws(e).into())
Err(crate::Error::new(e))
}
}
}
Expand All @@ -353,7 +352,7 @@ impl Sink<Message> for WebSocket {
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => {
log::debug!("websocket close error: {}", err);
Poll::Ready(Err(Kind::Ws(err).into()))
Poll::Ready(Err(crate::Error::new(err)))
}
}
}
Expand Down
64 changes: 51 additions & 13 deletions src/server.rs
@@ -1,8 +1,10 @@
use std::error::Error as StdError;
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use std::path::Path;
use crate::tls::TlsConfigBuilder;
use std::sync::Arc;
#[cfg(feature = "tls")]
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::future::Future;
Expand Down Expand Up @@ -44,7 +46,8 @@ pub struct Server<S> {
#[cfg(feature = "tls")]
pub struct TlsServer<S> {
server: Server<S>,
tls: ::rustls::ServerConfig,
tls: TlsConfigBuilder,
//tls: ::rustls::ServerConfig,
}

// Getting all various generic bounds to make this a re-usable method is
Expand Down Expand Up @@ -78,16 +81,17 @@ macro_rules! bind_inner {
let srv = HyperServer::builder(incoming)
.http1_pipeline_flush($this.pipeline)
.serve(service);
Ok::<_, hyper::error::Error>((addr, srv))
Ok::<_, hyper::Error>((addr, srv))
}};

(tls: $this:ident, $addr:expr) => {{
let service = into_service!($this.server.service);
let (addr, incoming) = addr_incoming!($addr);
let srv = HyperServer::builder(crate::tls::TlsAcceptor::new($this.tls, incoming))
let tls = $this.tls.build()?;
let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming))
.http1_pipeline_flush($this.server.pipeline)
.serve(service);
Ok::<_, hyper::error::Error>((addr, srv))
Ok::<_, Box<dyn std::error::Error + Send + Sync>>((addr, srv))
}};
}

Expand Down Expand Up @@ -231,10 +235,10 @@ where
pub fn try_bind_ephemeral(
self,
addr: impl Into<SocketAddr> + 'static,
) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), hyper::error::Error>
) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error>
{
let addr = addr.into();
let (addr, srv) = try_bind!(self, &addr)?;
let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?;
let srv = srv.map(|result| {
if let Err(err) = result {
log::error!("server error: {}", err)
Expand Down Expand Up @@ -334,14 +338,15 @@ where
self
}

/// Configure a server to use TLS with the supplied certificate and key files.
/// Configure a server to use TLS.
///
/// *This function requires the `"tls"` feature.*
#[cfg(feature = "tls")]
pub fn tls(self, cert: impl AsRef<Path>, key: impl AsRef<Path>) -> TlsServer<S> {
let tls = crate::tls::configure(cert.as_ref(), key.as_ref());

TlsServer { server: self, tls }
pub fn tls(self) -> TlsServer<S> {
TlsServer {
server: self,
tls: TlsConfigBuilder::new(),
}
}
}

Expand All @@ -354,6 +359,39 @@ where
<<S::Service as WarpService>::Reply as TryFuture>::Ok: Reply + Send,
<<S::Service as WarpService>::Reply as TryFuture>::Error: IsReject + Send,
{
// TLS config methods

/// Specify the file path to read the private key.
pub fn key_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.key_path(path))
}

/// Specify the file path to read the certificate.
pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
self.with_tls(|tls| tls.cert_path(path))
}

/// Specify the in-memory contents of the private key.
pub fn key(self, key: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.key(key.as_ref()))
}

/// Specify the in-memory contents of the certificate.
pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
self.with_tls(|tls| tls.cert(cert.as_ref()))
}

fn with_tls<F>(self, func: F) -> Self
where
F: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
{
let TlsServer { server, tls } = self;
let tls = func(tls);
TlsServer { server, tls }
}

// Server run methods

/// Run this `TlsServer` forever on the current thread.
///
/// *This function requires the `"tls"` feature.*
Expand All @@ -366,7 +404,7 @@ where
}

/// Bind to a socket address, returning a `Future` that can be
/// executed on any runtime.
/// executed on a runtime.
///
/// *This function requires the `"tls"` feature.*
pub async fn bind(
Expand Down

0 comments on commit 3a16d83

Please sign in to comment.