Skip to content

Commit

Permalink
Merge pull request #363 from shotover/master
Browse files Browse the repository at this point in the history
Subprotocol header
  • Loading branch information
agalakhov committed May 10, 2024
2 parents 60c50cd + 734234a commit 564f10a
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 4 deletions.
20 changes: 20 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,23 @@ pub enum CapacityError {
},
}

/// Indicates the specific type/cause of a subprotocol header error.
#[derive(Error, Clone, PartialEq, Eq, Debug, Copy)]
pub enum SubProtocolError {
/// The server sent a subprotocol to a client handshake request but none was requested
#[error("Server sent a subprotocol but none was requested")]
ServerSentSubProtocolNoneRequested,

/// The server sent an invalid subprotocol to a client handhshake request
#[error("Server sent an invalid subprotocol")]
InvalidSubProtocol,

/// The server sent no subprotocol to a client handshake request that requested one or more
/// subprotocols
#[error("Server sent no subprotocol")]
NoSubProtocol,
}

/// Indicates the specific type/cause of a protocol error.
#[allow(missing_copy_implementations)]
#[derive(Error, Debug, PartialEq, Eq, Clone)]
Expand All @@ -174,6 +191,9 @@ pub enum ProtocolError {
/// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value.
#[error("Key mismatch in \"Sec-WebSocket-Accept\" header")]
SecWebSocketAcceptKeyMismatch,
/// The `Sec-WebSocket-Protocol` header was invalid
#[error("SubProtocol error: {0}")]
SecWebSocketSubProtocolError(SubProtocolError),
/// Garbage data encountered after client request.
#[error("Junk after client request")]
JunkAfterRequest,
Expand Down
43 changes: 40 additions & 3 deletions src/handshake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use super::{
HandshakeRole, MidHandshake, ProcessingResult,
};
use crate::{
error::{Error, ProtocolError, Result, UrlError},
error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
protocol::{Role, WebSocket, WebSocketConfig},
};

Expand Down Expand Up @@ -54,6 +54,8 @@ impl<S: Read + Write> ClientHandshake<S> {
// Check the URI scheme: only ws or wss are supported
let _ = crate::client::uri_mode(request.uri())?;

let subprotocols = extract_subprotocols_from_request(&request)?;

// Convert and verify the `http::Request` and turn it into the request as per RFC.
// Also extract the key from it (it must be present in a correct request).
let (request, key) = generate_request(request)?;
Expand All @@ -62,7 +64,11 @@ impl<S: Read + Write> ClientHandshake<S> {

let client = {
let accept_key = derive_accept_key(key.as_ref());
ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData }
ClientHandshake {
verify_data: VerifyData { accept_key, subprotocols },
config,
_marker: PhantomData,
}
};

trace!("Client handshake initiated.");
Expand Down Expand Up @@ -178,11 +184,22 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
Ok((req, key))
}

fn extract_subprotocols_from_request(request: &Request) -> Result<Option<Vec<String>>> {
if let Some(subprotocols) = request.headers().get("Sec-WebSocket-Protocol") {
Ok(Some(subprotocols.to_str()?.split(",").map(|s| s.to_string()).collect()))
} else {
Ok(None)
}
}

/// Information for handshake verification.
#[derive(Debug)]
struct VerifyData {
/// Accepted server key.
accept_key: String,

/// Accepted subprotocols
subprotocols: Option<Vec<String>>,
}

impl VerifyData {
Expand Down Expand Up @@ -238,7 +255,27 @@ impl VerifyData {
// not present in the client's handshake (the server has indicated a
// subprotocol not requested by the client), the client MUST _Fail
// the WebSocket Connection_. (RFC 6455)
// TODO
if headers.get("Sec-WebSocket-Protocol").is_none() && self.subprotocols.is_some() {
return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
SubProtocolError::NoSubProtocol,
)));
}

if headers.get("Sec-WebSocket-Protocol").is_some() && self.subprotocols.is_none() {
return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
SubProtocolError::ServerSentSubProtocolNoneRequested,
)));
}

if let Some(returned_subprotocol) = headers.get("Sec-WebSocket-Protocol") {
if let Some(accepted_subprotocols) = &self.subprotocols {
if !accepted_subprotocols.contains(&returned_subprotocol.to_str()?.to_string()) {
return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
SubProtocolError::InvalidSubProtocol,
)));
}
}
}

Ok(response)
}
Expand Down
6 changes: 5 additions & 1 deletion tests/client_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn test_headers() {
}
});

let callback = |req: &Request, response: Response| {
let callback = |req: &Request, mut response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
Expand All @@ -64,6 +64,10 @@ fn test_headers() {
println!("Matching sec-websocket-protocol header");
assert_eq!(header.to_string(), web_socket_proto);
assert_eq!(value.to_str().unwrap(), sub_protocol);
// the server needs to respond with the same sub-protocol
response
.headers_mut()
.append("sec-websocket-protocol", sub_protocol.parse().unwrap());
}
}
Ok(response)
Expand Down
147 changes: 147 additions & 0 deletions tests/handshake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#![cfg(feature = "handshake")]
use std::{
net::TcpListener,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{
accept_hdr, connect,
error::{Error, ProtocolError, SubProtocolError},
handshake::{
client::generate_key,
server::{Request, Response},
},
};

fn create_http_request(uri: &str, subprotocols: Option<Vec<String>>) -> http::Request<()> {
let uri = uri.parse::<http::Uri>().unwrap();

let authority = uri.authority().unwrap().as_str();
let host =
authority.find('@').map(|idx| authority.split_at(idx + 1).1).unwrap_or_else(|| authority);

if host.is_empty() {
panic!("Empty host name");
}

let mut builder = http::Request::builder()
.method("GET")
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", generate_key());

if let Some(subprotocols) = subprotocols {
builder = builder.header("Sec-WebSocket-Protocol", subprotocols.join(","));
}

builder.uri(uri).body(()).unwrap()
}

fn server_thread(port: u16, server_subprotocols: Option<Vec<String>>) {
spawn(move || {
let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is this port already in use?");

let callback = |_request: &Request, mut response: Response| {
if let Some(subprotocols) = server_subprotocols {
let headers = response.headers_mut();
headers.append("Sec-WebSocket-Protocol", subprotocols.join(",").parse().unwrap());
}
Ok(response)
};

let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
client_handler.close(None).unwrap();
});
}

#[test]
fn test_server_send_no_subprotocol() {
server_thread(3012, None);
sleep(Duration::from_secs(1));

let err =
connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()])))
.unwrap_err();

assert!(matches!(
err,
Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
SubProtocolError::NoSubProtocol
))
));
}

#[test]
fn test_server_sent_subprotocol_none_requested() {
server_thread(3013, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));

let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err();

assert!(matches!(
err,
Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
SubProtocolError::ServerSentSubProtocolNoneRequested
))
));
}

#[test]
fn test_invalid_subprotocol() {
server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));

let err = connect(create_http_request(
"ws://127.0.0.1:3014",
Some(vec!["my-sub-protocol".to_string()]),
))
.unwrap_err();

assert!(matches!(
err,
Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
SubProtocolError::InvalidSubProtocol
))
));
}

#[test]
fn test_request_multiple_subprotocols() {
server_thread(3015, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let (_, response) = connect(create_http_request(
"ws://127.0.0.1:3015",
Some(vec![
"my-sub-protocol".to_string(),
"my-sub-protocol-1".to_string(),
"my-sub-protocol-2".to_string(),
]),
))
.unwrap();

assert_eq!(
response.headers().get("Sec-WebSocket-Protocol").unwrap(),
"my-sub-protocol".parse::<http::HeaderValue>().unwrap()
);
}

#[test]
fn test_request_single_subprotocol() {
server_thread(3016, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));

let (_, response) = connect(create_http_request(
"ws://127.0.0.1:3016",
Some(vec!["my-sub-protocol".to_string()]),
))
.unwrap();

assert_eq!(
response.headers().get("Sec-WebSocket-Protocol").unwrap(),
"my-sub-protocol".parse::<http::HeaderValue>().unwrap()
);
}

0 comments on commit 564f10a

Please sign in to comment.