Skip to content

Commit

Permalink
Websocket filter with tokio-tungstenite (#430)
Browse files Browse the repository at this point in the history
* tokio-tungstenite support

* debug message on upgrade failed

* formatting

* replaced tungstenite dependency with tokio-tungstenite v0.10

* removed WebSocket::with_context and WouldBolck check

* formatting

* removed WouldBlock check and ConnectionClosed check

* call ready! on Pin::new.poll
  • Loading branch information
AlexCovizzi committed Feb 6, 2020
1 parent 01725e9 commit 9dacd7d
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 184 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Expand Up @@ -34,7 +34,7 @@ tokio = { version = "0.2", features = ["blocking", "fs", "stream", "sync", "time
tower-service = "0.3"
rustls = { version = "0.16", optional = true }
# tls is enabled by default, we don't want that yet
tungstenite = { default-features = false, version = "0.9", optional = true }
tokio-tungstenite = { version = "0.10", default-features = false, optional = true }
urlencoding = "1.0.0"
pin-project = "0.4.5"

Expand All @@ -46,7 +46,7 @@ tokio = { version = "0.2", features = ["macros"] }

[features]
default = ["multipart", "websocket"]
websocket = ["tungstenite"]
websocket = ["tokio-tungstenite"]
tls = ["rustls"]

[profile.release]
Expand Down
12 changes: 5 additions & 7 deletions examples/rejections.rs
Expand Up @@ -27,13 +27,11 @@ async fn main() {

/// Extract a denominator from a "div-by" header, or reject with DivideByZero.
fn div_by() -> impl Filter<Extract = (NonZeroU16,), Error = Rejection> + Copy {
warp::header::<u16>("div-by").and_then(|n: u16| {
async move {
if let Some(denom) = NonZeroU16::new(n) {
Ok(denom)
} else {
Err(reject::custom(DivideByZero))
}
warp::header::<u16>("div-by").and_then(|n: u16| async move {
if let Some(denom) = NonZeroU16::new(n) {
Ok(denom)
} else {
Err(reject::custom(DivideByZero))
}
})
}
Expand Down
8 changes: 4 additions & 4 deletions examples/sse_chat.rs
Expand Up @@ -22,13 +22,13 @@ async fn main() {
.and(warp::post())
.and(warp::path::param::<usize>())
.and(warp::body::content_length_limit(500))
.and(warp::body::bytes().and_then(|body: bytes::Bytes| {
async move {
.and(
warp::body::bytes().and_then(|body: bytes::Bytes| async move {
std::str::from_utf8(&body)
.map(String::from)
.map_err(|_e| warp::reject::custom(NotUtf8))
}
}))
}),
)
.and(users.clone())
.map(|my_id, msg, users| {
user_message(my_id, msg, &users);
Expand Down
16 changes: 8 additions & 8 deletions src/filters/body.rs
Expand Up @@ -171,14 +171,14 @@ pub fn aggregate() -> impl Filter<Extract = (impl Buf,), Error = Rejection> + Co
/// });
/// ```
pub fn json<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error = Rejection> + Copy {
is_content_type::<Json>().and(aggregate()).and_then(|buf| {
async move {
is_content_type::<Json>()
.and(aggregate())
.and_then(|buf| async move {
Json::decode(buf).map_err(|err| {
log::debug!("request json body error: {}", err);
reject::known(BodyDeserializeError { cause: err })
})
}
})
})
}

/// Returns a `Filter` that matches any request and extracts a
Expand Down Expand Up @@ -206,14 +206,14 @@ pub fn json<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error =
/// });
/// ```
pub fn form<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error = Rejection> + Copy {
is_content_type::<Form>().and(aggregate()).and_then(|buf| {
async move {
is_content_type::<Form>()
.and(aggregate())
.and_then(|buf| async move {
Form::decode(buf).map_err(|err| {
log::debug!("request form body error: {}", err);
reject::known(BodyDeserializeError { cause: err })
})
}
})
})
}

// ===== Decoders =====
Expand Down
24 changes: 11 additions & 13 deletions src/filters/fs.rs
Expand Up @@ -89,20 +89,18 @@ fn path_from_tail(
base: Arc<PathBuf>,
) -> impl FilterClone<Extract = One<ArcPath>, Error = Rejection> {
crate::path::tail().and_then(move |tail: crate::path::Tail| {
future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| {
async {
let is_dir = tokio::fs::metadata(buf.clone())
.await
.map(|m| m.is_dir())
.unwrap_or(false);

if is_dir {
log::debug!("dir: appending index.html to directory path");
buf.push("index.html");
}
log::trace!("dir: {:?}", buf);
Ok(ArcPath(Arc::new(buf)))
future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| async {
let is_dir = tokio::fs::metadata(buf.clone())
.await
.map(|m| m.is_dir())
.unwrap_or(false);

if is_dir {
log::debug!("dir: appending index.html to directory path");
buf.push("index.html");
}
log::trace!("dir: {:?}", buf);
Ok(ArcPath(Arc::new(buf)))
})
})
}
Expand Down
172 changes: 29 additions & 143 deletions src/filters/ws.rs
Expand Up @@ -3,20 +3,20 @@
use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::pin::Pin;
use std::ptr::null_mut;
use std::task::{Context, Poll};

use super::{body, header};
use crate::filter::{Filter, One};
use crate::reject::Rejection;
use crate::reply::{Reply, Response};
use futures::{future, FutureExt, Sink, Stream, TryFutureExt};
use futures::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use http;
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::protocol::{self, WebSocketConfig};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
};

/// Creates a Websocket Filter.
///
Expand Down Expand Up @@ -132,18 +132,9 @@ where
.on_upgrade()
.and_then(move |upgraded| {
log::trace!("websocket upgrade complete");

let io = protocol::WebSocket::from_raw_socket(
AllowStd {
inner: upgraded,
context: (true, null_mut()),
},
protocol::Role::Server,
config,
);

on_upgrade(WebSocket { inner: io }).map(Ok)
WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
})
.and_then(move |socket| on_upgrade(socket).map(Ok))
.map(|result| {
if let Err(err) = result {
log::debug!("ws upgrade error: {}", err);
Expand All @@ -166,112 +157,18 @@ where

/// A websocket `Stream` and `Sink`, provided to `ws` filters.
pub struct WebSocket {
inner: protocol::WebSocket<AllowStd>,
}

/// wrapper around hyper Upgraded to allow Read/write from tungstenite's WebSocket
#[derive(Debug)]
pub(crate) struct AllowStd {
inner: ::hyper::upgrade::Upgraded,
context: (bool, *mut ()),
}

struct Guard<'a>(&'a mut WebSocket);

impl Drop for Guard<'_> {
fn drop(&mut self) {
(self.0).inner.get_mut().context = (true, null_mut());
}
}

// *mut () context is neither Send nor Sync
unsafe impl Send for AllowStd {}
unsafe impl Sync for AllowStd {}

impl AllowStd {
fn with_context<F, R>(&mut self, f: F) -> Poll<io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut ::hyper::upgrade::Upgraded>) -> Poll<io::Result<R>>,
{
unsafe {
if !self.context.0 {
//was called by start_send without context
return Poll::Pending;
}
assert!(!self.context.1.is_null());
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
}

impl Read for AllowStd {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}

impl Write for AllowStd {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}

fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}

fn cvt<T>(r: tungstenite::error::Result<T>, err_message: &str) -> Poll<Result<T, crate::Error>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(tungstenite::Error::Io(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
Poll::Pending
}
Err(e) => {
log::debug!("{} {}", err_message, e);
Poll::Ready(Err(crate::Error::new(e)))
}
}
inner: WebSocketStream<hyper::upgrade::Upgraded>,
}

impl WebSocket {
pub(crate) fn from_raw_socket(
inner: hyper::upgrade::Upgraded,
pub(crate) async fn from_raw_socket(
upgraded: hyper::upgrade::Upgraded,
role: protocol::Role,
config: Option<protocol::WebSocketConfig>,
) -> Self {
let ws = protocol::WebSocket::from_raw_socket(
AllowStd {
inner,
context: (false, null_mut()),
},
role,
config,
);

WebSocket { inner: ws }
}

fn with_context<F, R>(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R
where
F: FnOnce(&mut protocol::WebSocket<AllowStd>) -> R,
{
self.inner.get_mut().context = match ctx {
Some(ctx) => (true, ctx as *mut _ as *mut ()),
None => (false, null_mut()),
};

let g = Guard(self);
f(&mut (g.0).inner)
WebSocketStream::from_raw_socket(upgraded, role, config)
.map(|inner| WebSocket { inner })
.await
}

/// Gracefully close this websocket.
Expand All @@ -284,19 +181,16 @@ impl Stream for WebSocket {
type Item = Result<Message, crate::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match (*self).with_context(Some(cx), |s| s.read_message()) {
Ok(item) => Poll::Ready(Some(Ok(Message { inner: item }))),
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
Poll::Pending
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
Some(Err(e)) => {
log::debug!("websocket poll error: {}", e);
Poll::Ready(Some(Err(crate::Error::new(e))))
}
Err(::tungstenite::Error::ConnectionClosed) => {
None => {
log::trace!("websocket closed");
Poll::Ready(None)
}
Err(e) => {
log::debug!("websocket poll error: {}", e);
Poll::Ready(Some(Err(crate::Error::new(e))))
}
}
}
}
Expand All @@ -305,23 +199,15 @@ impl Sink<Message> for WebSocket {
type Error = crate::Error;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| {
cvt(s.write_pending(), "websocket poll_ready error")
})
match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(crate::Error::new(e))),
}
}

fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match self.with_context(None, |s| s.write_message(item.inner)) {
match Pin::new(&mut self.inner).start_send(item.inner) {
Ok(()) => Ok(()),
// Err(::tungstenite::Error::SendQueueFull(inner)) => {
// log::debug!("websocket send queue full");
// Err(::tungstenite::Error::SendQueueFull(inner))
// }
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
// the message was accepted and queued
// isn't an error.
Ok(())
}
Err(e) => {
log::debug!("websocket start_send error: {}", e);
Err(crate::Error::new(e))
Expand All @@ -330,15 +216,15 @@ impl Sink<Message> for WebSocket {
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.with_context(Some(cx), |s| {
cvt(s.write_pending(), "websocket poll_flush error")
})
match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(crate::Error::new(e))),
}
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.with_context(Some(cx), |s| s.close(None)) {
match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => {
log::debug!("websocket close error: {}", err);
Poll::Ready(Err(crate::Error::new(err)))
Expand Down
5 changes: 3 additions & 2 deletions src/test.rs
Expand Up @@ -469,7 +469,7 @@ impl WsBuilder {
let (rd_tx, rd_rx) = mpsc::unbounded_channel();

tokio::spawn(async move {
use tungstenite::protocol;
use tokio_tungstenite::tungstenite::protocol;

let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));

Expand Down Expand Up @@ -509,7 +509,8 @@ impl WsBuilder {
upgraded,
protocol::Role::Client,
Default::default(),
);
)
.await;

let (tx, rx) = ws.split();
let write = wr_rx.map(Ok).forward(tx).map(|_| ());
Expand Down

0 comments on commit 9dacd7d

Please sign in to comment.