Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websockets #1

Closed
davidpdrsn opened this issue Jun 9, 2021 · 0 comments · Fixed by #3
Closed

Websockets #1

davidpdrsn opened this issue Jun 9, 2021 · 0 comments · Fixed by #3

Comments

@davidpdrsn
Copy link
Member

Something along the lines of the following:

#![allow(unused_imports)]

use bytes::Bytes;
use futures::prelude::*;
use futures::SinkExt;
use http::{header::HeaderName, HeaderValue, Request, Response, StatusCode};
use http_body::Empty;
use hyper::{
    server::conn::AddrStream,
    upgrade::{OnUpgrade, Upgraded},
    Body,
};
use sha1::{Digest, Sha1};
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::{convert::Infallible, task::Poll};
use tokio_tungstenite::{
    tungstenite::protocol::{self, WebSocketConfig},
    WebSocketStream,
};
use tower::{make::Shared, ServiceBuilder};
use tower::{BoxError, Service};
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tower_http::LatencyUnit;

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Listen on every IPv4 interface, port 3000.
    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], 3000));

    // Each request is traced, then handed to the WebSocket upgrade service, which runs
    // `handle_socket` on every connection it upgrades.
    let service = ServiceBuilder::new()
        .layer(TraceLayer::new_for_http())
        .service(WebSocketUpgrade::new(handle_socket));

    let server = hyper::Server::bind(&addr).serve(Shared::new(service));
    server.await.unwrap();
}

/// Prints every item (message or receive error) yielded by `socket` until `recv` returns `None`.
async fn handle_socket(mut socket: WebSocket) {
    loop {
        match socket.recv().await {
            Some(message) => println!("received message: {:?}", message),
            None => break,
        }
    }
}

/// Tower `Service` that answers the server side of a WebSocket handshake and hands each
/// upgraded connection to `callback`.
#[derive(Debug, Clone)]
pub struct WebSocketUpgrade<F> {
    // Cloned per request and invoked with the `WebSocket` built from the upgraded connection.
    callback: F,
    // Passed to `WebSocketStream::from_raw_socket` for every upgraded connection.
    config: WebSocketConfig,
}

impl<F> WebSocketUpgrade<F> {
    pub fn new(callback: F) -> Self {
        Self {
            callback,
            config: WebSocketConfig::default(),
        }
    }
}

impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
where
    F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    type Response = Response<Empty<Bytes>>;
    type Error = BoxError;
    type Future = ResponseFuture;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        // TODO(david): missing `upgrade` should return "bad request"

        if !header_eq(
            &req,
            HeaderName::from_static("upgrade"),
            HeaderValue::from_static("websocket"),
        ) {
            todo!()
        }

        if !header_eq(
            &req,
            HeaderName::from_static("sec-websocket-version"),
            HeaderValue::from_static("13"),
        ) {
            todo!()
        }

        let key = if let Some(key) = req.headers_mut().remove("sec-websocket-key") {
            key
        } else {
            todo!()
        };

        let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();

        let config = self.config;
        let callback = self.callback.clone();

        tokio::spawn(async move {
            let upgraded = on_upgrade.await.unwrap();
            let socket =
                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
                    .await;
            let socket = WebSocket { inner: socket };
            callback(socket).await;
        });

        ResponseFuture { key: Some(key) }
    }
}

#[derive(Debug)]
pub struct ResponseFuture {
    key: Option<HeaderValue>,
}

impl Future for ResponseFuture {
    type Output = Result<Response<Empty<Bytes>>, BoxError>;

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let res = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS)
            .header(
                http::header::CONNECTION,
                HeaderValue::from_str("upgrade").unwrap(),
            )
            .header(
                http::header::UPGRADE,
                HeaderValue::from_str("websocket").unwrap(),
            )
            .header(
                http::header::SEC_WEBSOCKET_ACCEPT,
                sign(self.as_mut().key.take().unwrap().as_bytes()),
            )
            .body(Empty::new())
            .unwrap();

        Poll::Ready(Ok(res))
    }
}

/// Returns `true` if `req` has a `key` header whose value matches `value`, ignoring ASCII case.
///
/// The comparison is case-insensitive because RFC 6455 treats the `Upgrade: websocket` token
/// that way, so a client sending `Upgrade: WebSocket` must still be accepted. A missing header
/// yields `false`.
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
    req.headers().get(&key).map_or(false, |header| {
        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
    })
}

// from https://github.com/hyperium/headers/blob/master/src/common/sec_websocket_accept.rs#L38
/// Derives the `Sec-WebSocket-Accept` header value from a client's `Sec-WebSocket-Key`.
///
/// The value is base64(SHA-1(key ++ GUID)), where GUID is the fixed WebSocket magic string.
fn sign(key: &[u8]) -> HeaderValue {
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    hasher.update(key);
    hasher.update(WEBSOCKET_GUID);
    let digest = hasher.finalize();

    let encoded = Bytes::from(base64::encode(&digest));
    HeaderValue::from_maybe_shared(encoded).expect("base64 is a valid value")
}

/// Server-side WebSocket connection handed to the `WebSocketUpgrade` callback.
#[derive(Debug)]
pub struct WebSocket {
    // tokio-tungstenite stream over the connection hyper upgraded.
    inner: WebSocketStream<Upgraded>,
}

impl WebSocket {
    /// Waits for the next item from the peer.
    ///
    /// Returns `None` once the underlying stream is exhausted. Transport or protocol errors
    /// are boxed into `BoxError`.
    pub async fn recv(&mut self) -> Option<Result<protocol::Message, BoxError>> {
        let next = self.inner.next().await?;
        Some(next.map_err(|err| err.into()))
    }
}

// TODO(david): impl Stream<Message>
// TODO(david): WebSocket::close
@mortifia mentioned this issue on Oct 11, 2023
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant