Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket filter with tokio-tungstenite #430

Merged
merged 9 commits into from Feb 6, 2020
4 changes: 2 additions & 2 deletions Cargo.toml
Expand Up @@ -34,7 +34,7 @@ tokio = { version = "0.2", features = ["blocking", "fs", "stream", "sync", "time
tower-service = "0.3"
rustls = { version = "0.16", optional = true }
# tls is enabled by default, we don't want that yet
tungstenite = { default-features = false, version = "0.9", optional = true }
tokio-tungstenite = { version = "0.10", default-features = false, optional = true }
urlencoding = "1.0.0"
pin-project = "0.4.5"

Expand All @@ -46,7 +46,7 @@ tokio = { version = "0.2", features = ["macros"] }

[features]
default = ["multipart", "websocket"]
websocket = ["tungstenite"]
websocket = ["tokio-tungstenite"]
tls = ["rustls"]

[profile.release]
Expand Down
161 changes: 50 additions & 111 deletions src/filters/ws.rs
Expand Up @@ -3,9 +3,8 @@
use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::io::{self};
use std::pin::Pin;
use std::ptr::null_mut;
use std::task::{Context, Poll};

use super::{body, header};
Expand All @@ -15,7 +14,8 @@ use crate::reply::{Reply, Response};
use futures::{future, FutureExt, Sink, Stream, TryFutureExt};
use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use http;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::WebSocketStream;
use tungstenite::protocol::{self, WebSocketConfig};

/// Creates a Websocket Filter.
Expand Down Expand Up @@ -132,18 +132,9 @@ where
.on_upgrade()
.and_then(move |upgraded| {
log::trace!("websocket upgrade complete");

let io = protocol::WebSocket::from_raw_socket(
AllowStd {
inner: upgraded,
context: (true, null_mut()),
},
protocol::Role::Server,
config,
);

on_upgrade(WebSocket { inner: io }).map(Ok)
WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
})
.and_then(move |socket| on_upgrade(socket).map(Ok))
.map(|result| {
if let Err(err) = result {
log::debug!("ws upgrade error: {}", err);
Expand All @@ -166,70 +157,10 @@ where

/// A websocket `Stream` and `Sink`, provided to `ws` filters.
pub struct WebSocket {
inner: protocol::WebSocket<AllowStd>,
}

/// wrapper around hyper Upgraded to allow Read/write from tungstenite's WebSocket
#[derive(Debug)]
pub(crate) struct AllowStd {
inner: ::hyper::upgrade::Upgraded,
context: (bool, *mut ()),
}

struct Guard<'a>(&'a mut WebSocket);

impl Drop for Guard<'_> {
fn drop(&mut self) {
(self.0).inner.get_mut().context = (true, null_mut());
}
}

// *mut () context is neither Send nor Sync
unsafe impl Send for AllowStd {}
unsafe impl Sync for AllowStd {}

impl AllowStd {
fn with_context<F, R>(&mut self, f: F) -> Poll<io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut ::hyper::upgrade::Upgraded>) -> Poll<io::Result<R>>,
{
unsafe {
if !self.context.0 {
//was called by start_send without context
return Poll::Pending;
}
assert!(!self.context.1.is_null());
let waker = &mut *(self.context.1 as *mut _);
f(waker, Pin::new(&mut self.inner))
}
}
}

impl Read for AllowStd {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}

impl Write for AllowStd {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}

fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
inner: WebSocketStream<hyper::upgrade::Upgraded>,
}

/*
fn cvt<T>(r: tungstenite::error::Result<T>, err_message: &str) -> Poll<Result<T, crate::Error>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Expand All @@ -242,36 +173,27 @@ fn cvt<T>(r: tungstenite::error::Result<T>, err_message: &str) -> Poll<Result<T,
}
}
}
*/

impl WebSocket {
pub(crate) fn from_raw_socket(
inner: hyper::upgrade::Upgraded,
pub(crate) async fn from_raw_socket(
upgraded: hyper::upgrade::Upgraded,
role: protocol::Role,
config: Option<protocol::WebSocketConfig>,
) -> Self {
let ws = protocol::WebSocket::from_raw_socket(
AllowStd {
inner,
context: (false, null_mut()),
},
role,
config,
);

WebSocket { inner: ws }
WebSocketStream::from_raw_socket(upgraded, role, config)
.map(|inner| WebSocket { inner })
.await
}

fn with_context<F, R>(&mut self, ctx: Option<&mut Context<'_>>, f: F) -> R
jxs marked this conversation as resolved.
Show resolved Hide resolved
where
F: FnOnce(&mut protocol::WebSocket<AllowStd>) -> R,
F: FnOnce(
Pin<&mut WebSocketStream<hyper::upgrade::Upgraded>>,
Option<&mut Context<'_>>,
) -> R,
{
self.inner.get_mut().context = match ctx {
Some(ctx) => (true, ctx as *mut _ as *mut ()),
None => (false, null_mut()),
};

let g = Guard(self);
f(&mut (g.0).inner)
f(Pin::new(&mut self.inner), ctx)
}

/// Gracefully close this websocket.
Expand All @@ -284,19 +206,26 @@ impl Stream for WebSocket {
type Item = Result<Message, crate::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match (*self).with_context(Some(cx), |s| s.read_message()) {
Ok(item) => Poll::Ready(Some(Ok(Message { inner: item }))),
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
match (*self).with_context(Some(cx), |s, cx| s.poll_next(cx.unwrap())) {
Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(Ok(Message { inner: item }))),
Poll::Ready(Some(Err(tungstenite::Error::Io(ref err))))
if err.kind() == io::ErrorKind::WouldBlock =>
AlexCovizzi marked this conversation as resolved.
Show resolved Hide resolved
jxs marked this conversation as resolved.
Show resolved Hide resolved
{
Poll::Pending
}
Err(::tungstenite::Error::ConnectionClosed) => {
Poll::Ready(Some(Err(tungstenite::Error::ConnectionClosed))) => {
log::trace!("websocket closed");
Poll::Ready(None)
}
Err(e) => {
Poll::Ready(Some(Err(e))) => {
log::debug!("websocket poll error: {}", e);
Poll::Ready(Some(Err(crate::Error::new(e))))
}
Poll::Ready(None) => {
log::trace!("websocket closed");
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
Expand All @@ -305,19 +234,24 @@ impl Sink<Message> for WebSocket {
type Error = crate::Error;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some(cx), |s| {
cvt(s.write_pending(), "websocket poll_ready error")
(*self).with_context(Some(cx), |s, cx| {
match s.poll_ready(cx.unwrap()) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(crate::Error::new(e))),
Poll::Pending => Poll::Pending,
}
// cvt(s.write_pending(), "websocket poll_ready error")
})
}

fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match self.with_context(None, |s| s.write_message(item.inner)) {
match self.with_context(None, |s, _| s.start_send(item.inner)) {
Ok(()) => Ok(()),
// Err(::tungstenite::Error::SendQueueFull(inner)) => {
// log::debug!("websocket send queue full");
// Err(::tungstenite::Error::SendQueueFull(inner))
// }
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
Err(tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
jxs marked this conversation as resolved.
Show resolved Hide resolved
// the message was accepted and queued
// isn't an error.
Ok(())
Expand All @@ -330,19 +264,24 @@ impl Sink<Message> for WebSocket {
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.with_context(Some(cx), |s| {
cvt(s.write_pending(), "websocket poll_flush error")
self.with_context(Some(cx), |s, cx| {
//cvt(s.write_pending(), "websocket poll_flush error")
match s.poll_flush(cx.unwrap()) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(crate::Error::new(e))),
Poll::Pending => Poll::Pending,
}
})
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.with_context(Some(cx), |s| s.close(None)) {
Ok(()) => Poll::Ready(Ok(())),
Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
Err(err) => {
match self.with_context(Some(cx), |s, cx| s.poll_close(cx.unwrap())) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => {
log::debug!("websocket close error: {}", err);
Poll::Ready(Err(crate::Error::new(err)))
}
Poll::Pending => Poll::Pending,
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/test.rs
Expand Up @@ -469,7 +469,7 @@ impl WsBuilder {
let (rd_tx, rd_rx) = mpsc::unbounded_channel();

tokio::spawn(async move {
use tungstenite::protocol;
use tokio_tungstenite::tungstenite::protocol;

let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));

Expand Down Expand Up @@ -509,7 +509,8 @@ impl WsBuilder {
upgraded,
protocol::Role::Client,
Default::default(),
);
)
.await;

let (tx, rx) = ws.split();
let write = wr_rx.map(Ok).forward(tx).map(|_| ());
Expand Down