diff --git a/src/lib.rs b/src/lib.rs index 70e2b84..734fc64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,6 +193,11 @@ pub struct WebSocketStream { inner: WebSocket>, closing: bool, ended: bool, + /// Tungstenite is probably ready to receive more data. + /// + /// `false` once start_send hits `WouldBlock` errors. + /// `true` initially and after `flush`ing. + ready: bool, } impl WebSocketStream { @@ -226,7 +231,7 @@ impl WebSocketStream { } pub(crate) fn new(ws: WebSocket>) -> Self { - WebSocketStream { inner: ws, closing: false, ended: false } + Self { inner: ws, closing: false, ended: false, ready: true } } fn with_context(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R @@ -321,19 +326,32 @@ where { type Error = WsError; - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.ready { + Poll::Ready(Ok(())) + } else { + // Currently blocked so try to flush the blockage away + (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| { + self.ready = true; + r + }) + } } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { match (*self).with_context(None, |s| s.write(item)) { - Ok(()) => Ok(()), + Ok(()) => { + self.ready = true; + Ok(()) + } Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => { - // the message was accepted and queued - // isn't an error. + // the message was accepted and queued so not an error + // but `poll_ready` will now start trying to flush the block + self.ready = false; Ok(()) } Err(e) => { + self.ready = true; debug!("websocket start_send error: {}", e); Err(e) } @@ -342,8 +360,9 @@ where fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| { - // WebSocket connection has just been closed. Flushing completed, not an error. + self.ready = true; match r { + // WebSocket connection has just been closed. Flushing completed, not an error. Err(WsError::ConnectionClosed) => Ok(()), other => other, } @@ -351,6 +370,7 @@ where } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.ready = true; let res = if self.closing { // After queueing it, we call `flush` to drive the close handshake to completion. (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())