Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket filter with tokio-tungstenite #430

Merged
merged 9 commits into from Feb 6, 2020
4 changes: 2 additions & 2 deletions Cargo.toml
Expand Up @@ -34,7 +34,7 @@ tokio = { version = "0.2", features = ["blocking", "fs", "stream", "sync", "time
tower-service = "0.3"
rustls = { version = "0.16", optional = true }
# tls is enabled by default, we don't want that yet
tungstenite = { default-features = false, version = "0.9", optional = true }
tokio-tungstenite = { version = "0.10", default-features = false, optional = true }
urlencoding = "1.0.0"
pin-project = "0.4.5"

Expand All @@ -46,7 +46,7 @@ tokio = { version = "0.2", features = ["macros"] }

[features]
default = ["multipart", "websocket"]
websocket = ["tungstenite"]
websocket = ["tokio-tungstenite"]
tls = ["rustls"]

[profile.release]
Expand Down
12 changes: 5 additions & 7 deletions examples/rejections.rs
Expand Up @@ -27,13 +27,11 @@ async fn main() {

/// Extract a denominator from a "div-by" header, or reject with DivideByZero.
fn div_by() -> impl Filter<Extract = (NonZeroU16,), Error = Rejection> + Copy {
warp::header::<u16>("div-by").and_then(|n: u16| {
async move {
if let Some(denom) = NonZeroU16::new(n) {
Ok(denom)
} else {
Err(reject::custom(DivideByZero))
}
warp::header::<u16>("div-by").and_then(|n: u16| async move {
if let Some(denom) = NonZeroU16::new(n) {
Ok(denom)
} else {
Err(reject::custom(DivideByZero))
}
})
}
Expand Down
8 changes: 4 additions & 4 deletions examples/sse_chat.rs
Expand Up @@ -22,13 +22,13 @@ async fn main() {
.and(warp::post())
.and(warp::path::param::<usize>())
.and(warp::body::content_length_limit(500))
.and(warp::body::bytes().and_then(|body: bytes::Bytes| {
async move {
.and(
warp::body::bytes().and_then(|body: bytes::Bytes| async move {
std::str::from_utf8(&body)
.map(String::from)
.map_err(|_e| warp::reject::custom(NotUtf8))
}
}))
}),
)
.and(users.clone())
.map(|my_id, msg, users| {
user_message(my_id, msg, &users);
Expand Down
16 changes: 8 additions & 8 deletions src/filters/body.rs
Expand Up @@ -171,14 +171,14 @@ pub fn aggregate() -> impl Filter<Extract = (impl Buf,), Error = Rejection> + Co
/// });
/// ```
pub fn json<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error = Rejection> + Copy {
is_content_type::<Json>().and(aggregate()).and_then(|buf| {
async move {
is_content_type::<Json>()
.and(aggregate())
.and_then(|buf| async move {
Json::decode(buf).map_err(|err| {
log::debug!("request json body error: {}", err);
reject::known(BodyDeserializeError { cause: err })
})
}
})
})
}

/// Returns a `Filter` that matches any request and extracts a
Expand Down Expand Up @@ -206,14 +206,14 @@ pub fn json<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error =
/// });
/// ```
pub fn form<T: DeserializeOwned + Send>() -> impl Filter<Extract = (T,), Error = Rejection> + Copy {
is_content_type::<Form>().and(aggregate()).and_then(|buf| {
async move {
is_content_type::<Form>()
.and(aggregate())
.and_then(|buf| async move {
Form::decode(buf).map_err(|err| {
log::debug!("request form body error: {}", err);
reject::known(BodyDeserializeError { cause: err })
})
}
})
})
}

// ===== Decoders =====
Expand Down
24 changes: 11 additions & 13 deletions src/filters/fs.rs
Expand Up @@ -89,20 +89,18 @@ fn path_from_tail(
base: Arc<PathBuf>,
) -> impl FilterClone<Extract = One<ArcPath>, Error = Rejection> {
crate::path::tail().and_then(move |tail: crate::path::Tail| {
future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| {
async {
let is_dir = tokio::fs::metadata(buf.clone())
.await
.map(|m| m.is_dir())
.unwrap_or(false);

if is_dir {
log::debug!("dir: appending index.html to directory path");
buf.push("index.html");
}
log::trace!("dir: {:?}", buf);
Ok(ArcPath(Arc::new(buf)))
future::ready(sanitize_path(base.as_ref(), tail.as_str())).and_then(|mut buf| async {
let is_dir = tokio::fs::metadata(buf.clone())
.await
.map(|m| m.is_dir())
.unwrap_or(false);

if is_dir {
log::debug!("dir: appending index.html to directory path");
buf.push("index.html");
}
log::trace!("dir: {:?}", buf);
Ok(ArcPath(Arc::new(buf)))
})
})
}
Expand Down
183 changes: 43 additions & 140 deletions src/filters/ws.rs
Expand Up @@ -3,9 +3,8 @@
use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::io::{self};
use std::pin::Pin;
use std::ptr::null_mut;
use std::task::{Context, Poll};

use super::{body, header};
Expand All @@ -15,8 +14,13 @@ use crate::reply::{Reply, Response};
use futures::{future, FutureExt, Sink, Stream, TryFutureExt};
use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use http;
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::protocol::{self, WebSocketConfig};
use tokio_tungstenite::{
tungstenite::{
self,
protocol::{self, WebSocketConfig},
},
WebSocketStream,
};

/// Creates a Websocket Filter.
///
Expand Down Expand Up @@ -132,18 +136,9 @@ where
.on_upgrade()
.and_then(move |upgraded| {
log::trace!("websocket upgrade complete");

let io = protocol::WebSocket::from_raw_socket(
AllowStd {
inner: upgraded,
context: (true, null_mut()),
},
protocol::Role::Server,
config,
);

on_upgrade(WebSocket { inner: io }).map(Ok)
WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
})
.and_then(move |socket| on_upgrade(socket).map(Ok))
.map(|result| {
if let Err(err) = result {
log::debug!("ws upgrade error: {}", err);
Expand All @@ -166,112 +161,18 @@ where

/// A websocket `Stream` and `Sink`, provided to `ws` filters.
pub struct WebSocket {
inner: protocol::WebSocket<AllowStd>,
}

/// wrapper around hyper Upgraded to allow Read/write from tungstenite's WebSocket
#[derive(Debug)]
pub(crate) struct AllowStd {
inner: ::hyper::upgrade::Upgraded,
context: (bool, *mut ()),
}

struct Guard<'a>(&'a mut WebSocket);

impl Drop for Guard<'_> {
fn drop(&mut self) {
(self.0).inner.get_mut().context = (true, null_mut());
}
}

// *mut () context is neither Send nor Sync
unsafe impl Send for AllowStd {}
unsafe impl Sync for AllowStd {}

impl AllowStd {
fn with_context<F, R>(&mut self, f: F) -> Poll<io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut ::hyper::upgrade::Upgraded>) -> Poll<io::Result<R>>,
{
unsafe {
if !self.context.0 {
//was called by start_send without context
return Poll::Pending;
}
assert!(!self.context.1.is_null());
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
}

impl Read for AllowStd {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}

impl Write for AllowStd {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}

fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}

fn cvt<T>(r: tungstenite::error::Result<T>, err_message: &str) -> Poll<Result<T, crate::Error>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(tungstenite::Error::Io(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
Poll::Pending
}
Err(e) => {
log::debug!("{} {}", err_message, e);
Poll::Ready(Err(crate::Error::new(e)))
}
}
inner: WebSocketStream<hyper::upgrade::Upgraded>,
}

impl WebSocket {
pub(crate) fn from_raw_socket(
inner: hyper::upgrade::Upgraded,
pub(crate) async fn from_raw_socket(
upgraded: hyper::upgrade::Upgraded,
role: protocol::Role,
config: Option<protocol::WebSocketConfig>,
) -> Self {
let ws = protocol::WebSocket::from_raw_socket(
AllowStd {
inner,
context: (false, null_mut()),
},
role,
config,
);

WebSocket { inner: ws }
}

fn with_context<F, R>(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R
where
F: FnOnce(&mut protocol::WebSocket<AllowStd>) -> R,
{
self.inner.get_mut().context = match ctx {
Some(ctx) => (true, ctx as *mut _ as *mut ()),
None => (false, null_mut()),
};

let g = Guard(self);
f(&mut (g.0).inner)
WebSocketStream::from_raw_socket(upgraded, role, config)
.map(|inner| WebSocket { inner })
.await
}

/// Gracefully close this websocket.
Expand All @@ -284,19 +185,26 @@ impl Stream for WebSocket {
type Item = Result<Message, crate::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match (*self).with_context(Some(cx), |s| s.read_message()) {
Ok(item) => Poll::Ready(Some(Ok(Message { inner: item }))),
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
match Pin::new(&mut self.inner).poll_next(cx) {
Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(Ok(Message { inner: item }))),
Poll::Ready(Some(Err(tungstenite::Error::Io(ref err))))
if err.kind() == io::ErrorKind::WouldBlock =>
AlexCovizzi marked this conversation as resolved.
Show resolved Hide resolved
jxs marked this conversation as resolved.
Show resolved Hide resolved
{
Poll::Pending
}
Err(::tungstenite::Error::ConnectionClosed) => {
Poll::Ready(Some(Err(tungstenite::Error::ConnectionClosed))) => {
log::trace!("websocket closed");
Poll::Ready(None)
}
Err(e) => {
Poll::Ready(Some(Err(e))) => {
log::debug!("websocket poll error: {}", e);
Poll::Ready(Some(Err(crate::Error::new(e))))
}
Poll::Ready(None) => {
log::trace!("websocket closed");
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
Expand All @@ -305,23 +213,16 @@ impl Sink<Message> for WebSocket {
type Error = crate::Error;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| {
cvt(s.write_pending(), "websocket poll_ready error")
})
match Pin::new(&mut self.inner).poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(crate::Error::new(e))),
Poll::Pending => Poll::Pending,
}
}

fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match self.with_context(None, |s| s.write_message(item.inner)) {
match Pin::new(&mut self.inner).start_send(item.inner) {
Ok(()) => Ok(()),
// Err(::tungstenite::Error::SendQueueFull(inner)) => {
// log::debug!("websocket send queue full");
// Err(::tungstenite::Error::SendQueueFull(inner))
// }
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
// the message was accepted and queued
// isn't an error.
Ok(())
}
Err(e) => {
log::debug!("websocket start_send error: {}", e);
Err(crate::Error::new(e))
Expand All @@ -330,19 +231,21 @@ impl Sink<Message> for WebSocket {
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.with_context(Some(cx), |s| {
cvt(s.write_pending(), "websocket poll_flush error")
})
match Pin::new(&mut self.inner).poll_flush(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(crate::Error::new(e))),
Poll::Pending => Poll::Pending,
}
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.with_context(Some(cx), |s| s.close(None)) {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => {
match Pin::new(&mut self.inner).poll_close(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => {
log::debug!("websocket close error: {}", err);
Poll::Ready(Err(crate::Error::new(err)))
}
Poll::Pending => Poll::Pending,
}
}
}
Expand Down