Skip to content

Commit

Permalink
Re-port #90 (#194)
Browse files Browse the repository at this point in the history
* Don't project if you require Unpin
Co-authored-by: Jens Reidel <adrian@travitia.xyz>

* remove async-await feature on futures-util

Co-authored-by: Naja Melan <najamelan@autistici.org>
  • Loading branch information
Gelbpunkt and najamelan committed Oct 16, 2021
1 parent 2bfd4cd commit b40114e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Expand Up @@ -27,8 +27,7 @@ stream = []

[dependencies]
log = "0.4"
futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"] }
pin-project = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
tokio = { version = "1.0.0", default-features = false, features = ["io-util"] }

[dependencies.tungstenite]
Expand Down
15 changes: 6 additions & 9 deletions src/handshake.rs
Expand Up @@ -3,7 +3,6 @@ use crate::{
WebSocketStream,
};
use log::*;
use pin_project::pin_project;
use std::{
future::Future,
io::{Read, Write},
Expand Down Expand Up @@ -54,7 +53,6 @@ where
}
}

#[pin_project]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);

enum StartedHandshake<Role: HandshakeRole> {
Expand All @@ -71,7 +69,7 @@ struct StartedHandshakeFutureInner<F, S> {
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: SetWaker,
Role::InternalStream: SetWaker + Unpin,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
Expand Down Expand Up @@ -125,7 +123,7 @@ where
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: SetWaker,
Role::InternalStream: SetWaker + Unpin,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
Expand All @@ -148,13 +146,12 @@ where
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: SetWaker,
Role::InternalStream: SetWaker + Unpin,
{
type Output = Result<Role::FinalResult, Error<Role>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut s = this.0.take().expect("future polled after completion");
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut s = self.as_mut().0.take().expect("future polled after completion");

let machine = s.get_mut();
trace!("Setting context in handshake");
Expand All @@ -164,7 +161,7 @@ where
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mid)) => {
*this.0 = Some(mid);
self.0 = Some(mid);
Poll::Pending
}
}
Expand Down
36 changes: 17 additions & 19 deletions src/stream.rs
Expand Up @@ -3,7 +3,6 @@
//! There is no dependency on actual TLS implementations. Everything like
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.
use pin_project::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
Expand All @@ -13,11 +12,10 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// A stream that might be protected with TLS.
#[non_exhaustive]
#[pin_project(project = StreamProj)]
#[derive(Debug)]
pub enum MaybeTlsStream<S> {
/// Unencrypted socket stream.
Plain(#[pin] S),
Plain(S),
/// Encrypted socket stream using `native-tls`.
#[cfg(feature = "native-tls")]
NativeTls(tokio_native_tls::TlsStream<S>),
Expand All @@ -32,12 +30,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.project() {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
Expand All @@ -48,35 +46,35 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_write(cx, buf),
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
StreamProj::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "native-tls")]
StreamProj::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "__rustls-tls")]
StreamProj::Rustls(s) => Pin::new(s).poll_shutdown(cx),
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
}
}
}

0 comments on commit b40114e

Please sign in to comment.