diff --git a/Cargo.toml b/Cargo.toml index 2aa00e0e9..7293f62f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ tokio = { version = "0.2", features = ["blocking", "fs", "stream", "sync", "time tower-service = "0.3" rustls = { version = "0.16", optional = true } # tls is enabled by default, we don't want that yet -tungstenite = { default-features = false, version = "0.9", optional = true } +tokio-tungstenite = { version = "0.10", default-features = false, optional = true } urlencoding = "1.0.0" pin-project = "0.4.5" @@ -46,7 +46,7 @@ tokio = { version = "0.2", features = ["macros"] } [features] default = ["multipart", "websocket"] -websocket = ["tungstenite"] +websocket = ["tokio-tungstenite"] tls = ["rustls"] [profile.release] diff --git a/examples/rejections.rs b/examples/rejections.rs index 3ccc51167..286b7772b 100644 --- a/examples/rejections.rs +++ b/examples/rejections.rs @@ -27,13 +27,11 @@ async fn main() { /// Extract a denominator from a "div-by" header, or reject with DivideByZero. fn div_by() -> impl Filter + Copy { - warp::header::("div-by").and_then(|n: u16| { - async move { - if let Some(denom) = NonZeroU16::new(n) { - Ok(denom) - } else { - Err(reject::custom(DivideByZero)) - } + warp::header::("div-by").and_then(|n: u16| async move { + if let Some(denom) = NonZeroU16::new(n) { + Ok(denom) + } else { + Err(reject::custom(DivideByZero)) } }) } diff --git a/examples/sse_chat.rs b/examples/sse_chat.rs index ffb3064ed..693a87f55 100644 --- a/examples/sse_chat.rs +++ b/examples/sse_chat.rs @@ -22,13 +22,13 @@ async fn main() { .and(warp::post()) .and(warp::path::param::()) .and(warp::body::content_length_limit(500)) - .and(warp::body::bytes().and_then(|body: bytes::Bytes| { - async move { + .and( + warp::body::bytes().and_then(|body: bytes::Bytes| async move { std::str::from_utf8(&body) .map(String::from) .map_err(|_e| warp::reject::custom(NotUtf8)) - } - })) + }), + ) .and(users.clone()) .map(|my_id, msg, users| { user_message(my_id, msg, &users); diff --git a/src/filters/body.rs b/src/filters/body.rs index 7b57442e1..771ddce48 100644 --- a/src/filters/body.rs +++ b/src/filters/body.rs @@ -171,14 +171,14 @@ pub fn aggregate() -> impl Filter + Co /// }); /// ``` pub fn json() -> impl Filter + Copy { - is_content_type::().and(aggregate()).and_then(|buf| { - async move { + is_content_type::() + .and(aggregate()) + .and_then(|buf| async move { Json::decode(buf).map_err(|err| { log::debug!("request json body error: {}", err); reject::known(BodyDeserializeError { cause: err }) }) - } - }) + }) } /// Returns a `Filter` that matches any request and extracts a @@ -206,14 +206,14 @@ pub fn json() -> impl Filter() -> impl Filter + Copy { - is_content_type::
().and(aggregate()).and_then(|buf| { - async move { + is_content_type::() + .and(aggregate()) + .and_then(|buf| async move { Form::decode(buf).map_err(|err| { log::debug!("request form body error: {}", err); reject::known(BodyDeserializeError { cause: err }) }) - } - }) + }) } // ===== Decoders ===== diff --git a/src/filters/fs.rs b/src/filters/fs.rs index bba22745f..c79725c68 100644 --- a/src/filters/fs.rs +++ b/src/filters/fs.rs @@ -89,20 +89,18 @@ fn path_from_tail( base: Arc, ) -> impl FilterClone, Error = Rejection> { crate::path::tail().and_then(move |tail: crate::path::Tail| { - future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| { - async { - let is_dir = tokio::fs::metadata(buf.clone()) - .await - .map(|m| m.is_dir()) - .unwrap_or(false); - - if is_dir { - log::debug!("dir: appending index.html to directory path"); - buf.push("index.html"); - } - log::trace!("dir: {:?}", buf); - Ok(ArcPath(Arc::new(buf))) + future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| async { + let is_dir = tokio::fs::metadata(buf.clone()) + .await + .map(|m| m.is_dir()) + .unwrap_or(false); + + if is_dir { + log::debug!("dir: appending index.html to directory path"); + buf.push("index.html"); } + log::trace!("dir: {:?}", buf); + Ok(ArcPath(Arc::new(buf))) }) }) } diff --git a/src/filters/ws.rs b/src/filters/ws.rs index d34af0fbe..2dd8c8d9d 100644 --- a/src/filters/ws.rs +++ b/src/filters/ws.rs @@ -3,20 +3,20 @@ use std::borrow::Cow; use std::fmt; use std::future::Future; -use std::io::{self, Read, Write}; use std::pin::Pin; -use std::ptr::null_mut; use std::task::{Context, Poll}; use super::{body, header}; use crate::filter::{Filter, One}; use crate::reject::Rejection; use crate::reply::{Reply, Response}; -use futures::{future, FutureExt, Sink, Stream, TryFutureExt}; +use futures::{future, ready, FutureExt, Sink, Stream, TryFutureExt}; use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade}; use http; -use tokio::io::{AsyncRead, AsyncWrite}; -use tungstenite::protocol::{self, WebSocketConfig}; +use tokio_tungstenite::{ + tungstenite::protocol::{self, WebSocketConfig}, + WebSocketStream, +}; /// Creates a Websocket Filter. /// @@ -132,18 +132,9 @@ where .on_upgrade() .and_then(move |upgraded| { log::trace!("websocket upgrade complete"); - - let io = protocol::WebSocket::from_raw_socket( - AllowStd { - inner: upgraded, - context: (true, null_mut()), - }, - protocol::Role::Server, - config, - ); - - on_upgrade(WebSocket { inner: io }).map(Ok) + WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok) }) + .and_then(move |socket| on_upgrade(socket).map(Ok)) .map(|result| { if let Err(err) = result { log::debug!("ws upgrade error: {}", err); @@ -166,112 +157,18 @@ where /// A websocket `Stream` and `Sink`, provided to `ws` filters. pub struct WebSocket { - inner: protocol::WebSocket, -} - -/// wrapper around hyper Upgraded to allow Read/write from tungstenite's WebSocket -#[derive(Debug)] -pub(crate) struct AllowStd { - inner: ::hyper::upgrade::Upgraded, - context: (bool, *mut ()), -} - -struct Guard<'a>(&'a mut WebSocket); - -impl Drop for Guard<'_> { - fn drop(&mut self) { - (self.0).inner.get_mut().context = (true, null_mut()); - } -} - -// *mut () context is neither Send nor Sync -unsafe impl Send for AllowStd {} -unsafe impl Sync for AllowStd {} - -impl AllowStd { - fn with_context(&mut self, f: F) -> Poll> - where - F: FnOnce(&mut Context<'_>, Pin<&mut ::hyper::upgrade::Upgraded>) -> Poll>, - { - unsafe { - if !self.context.0 { - //was called by start_send without context - return Poll::Pending; - } - assert!(!self.context.1.is_null()); - let waker = &mut *(self.context.1 as *mut _); - f(waker, Pin::new(&mut self.inner)) - } - } -} - -impl Read for AllowStd { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -impl Write for AllowStd { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -fn cvt(r: tungstenite::error::Result, err_message: &str) -> Poll> { - match r { - Ok(v) => Poll::Ready(Ok(v)), - Err(tungstenite::Error::Io(ref e)) if e.kind() == io::ErrorKind::WouldBlock => { - Poll::Pending - } - Err(e) => { - log::debug!("{} {}", err_message, e); - Poll::Ready(Err(crate::Error::new(e))) - } - } + inner: WebSocketStream, } impl WebSocket { - pub(crate) fn from_raw_socket( - inner: hyper::upgrade::Upgraded, + pub(crate) async fn from_raw_socket( + upgraded: hyper::upgrade::Upgraded, role: protocol::Role, config: Option, ) -> Self { - let ws = protocol::WebSocket::from_raw_socket( - AllowStd { - inner, - context: (false, null_mut()), - }, - role, - config, - ); - - WebSocket { inner: ws } - } - - fn with_context(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R - where - F: FnOnce(&mut protocol::WebSocket) -> R, - { - self.inner.get_mut().context = match ctx { - Some(ctx) => (true, ctx as *mut _ as *mut ()), - None => (false, null_mut()), - }; - - let g = Guard(self); - f(&mut (g.0).inner) + WebSocketStream::from_raw_socket(upgraded, role, config) + .map(|inner| WebSocket { inner }) + .await } /// Gracefully close this websocket. @@ -284,19 +181,16 @@ impl Stream for WebSocket { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match (*self).with_context(Some(cx), |s| s.read_message()) { - Ok(item) => Poll::Ready(Some(Ok(Message { inner: item }))), - Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => { - Poll::Pending + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))), + Some(Err(e)) => { + log::debug!("websocket poll error: {}", e); + Poll::Ready(Some(Err(crate::Error::new(e)))) } - Err(::tungstenite::Error::ConnectionClosed) => { + None => { log::trace!("websocket closed"); Poll::Ready(None) } - Err(e) => { - log::debug!("websocket poll error: {}", e); - Poll::Ready(Some(Err(crate::Error::new(e)))) - } } } } @@ -305,23 +199,15 @@ impl Sink for WebSocket { type Error = crate::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - (*self).with_context(Some(cx), |s| { - cvt(s.write_pending(), "websocket poll_ready error") - }) + match ready!(Pin::new(&mut self.inner).poll_ready(cx)) { + Ok(()) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(crate::Error::new(e))), + } } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - match self.with_context(None, |s| s.write_message(item.inner)) { + match Pin::new(&mut self.inner).start_send(item.inner) { Ok(()) => Ok(()), - // Err(::tungstenite::Error::SendQueueFull(inner)) => { - // log::debug!("websocket send queue full"); - // Err(::tungstenite::Error::SendQueueFull(inner)) - // } - Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => { - // the message was accepted and queued - // isn't an error. - Ok(()) - } Err(e) => { log::debug!("websocket start_send error: {}", e); Err(crate::Error::new(e)) @@ -330,15 +216,15 @@ impl Sink for WebSocket { } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.with_context(Some(cx), |s| { - cvt(s.write_pending(), "websocket poll_flush error") - }) + match ready!(Pin::new(&mut self.inner).poll_flush(cx)) { + Ok(()) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(crate::Error::new(e))), + } } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match self.with_context(Some(cx), |s| s.close(None)) { + match ready!(Pin::new(&mut self.inner).poll_close(cx)) { Ok(()) => Poll::Ready(Ok(())), - Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())), Err(err) => { log::debug!("websocket close error: {}", err); Poll::Ready(Err(crate::Error::new(err))) diff --git a/src/test.rs b/src/test.rs index be906793f..bcb002369 100644 --- a/src/test.rs +++ b/src/test.rs @@ -469,7 +469,7 @@ impl WsBuilder { let (rd_tx, rd_rx) = mpsc::unbounded_channel(); tokio::spawn(async move { - use tungstenite::protocol; + use tokio_tungstenite::tungstenite::protocol; let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0)); @@ -509,7 +509,8 @@ impl WsBuilder { upgraded, protocol::Role::Client, Default::default(), - ); + ) + .await; let (tx, rx) = ws.split(); let write = wr_rx.map(Ok).forward(tx).map(|_| ()); diff --git a/tests/filter.rs b/tests/filter.rs index d200db173..5f6941884 100644 --- a/tests/filter.rs +++ b/tests/filter.rs @@ -149,11 +149,9 @@ async fn unify() { #[should_panic] #[tokio::test] async fn nested() { - let f = warp::any().and_then(|| { - async { - let p = warp::path::param::(); - warp::test::request().filter(&p).await - } + let f = warp::any().and_then(|| async { + let p = warp::path::param::(); + warp::test::request().filter(&p).await }); let _ = warp::test::request().filter(&f).await;