diff --git a/CHANGELOG.md b/CHANGELOG.md index b78c43b..8939190 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 0.13.0 + +- Add `CapacityError`, `UrlError`, and `ProtocolError` types to represent the different types of capacity, URL, and protocol errors respectively. +- Modify variants `Error::Capacity`, `Error::Url`, and `Error::Protocol` to hold the above errors types instead of string error messages. + # 0.12.0 - Add facilities to allow clients to follow HTTP 3XX redirects. diff --git a/Cargo.toml b/Cargo.toml index 9ec4595..cbab804 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" homepage = "https://github.com/snapview/tungstenite-rs" documentation = "https://docs.rs/tungstenite/0.12.0" repository = "https://github.com/snapview/tungstenite-rs" -version = "0.12.0" +version = "0.13.0" edition = "2018" [features] @@ -29,6 +29,7 @@ rand = "0.8.0" sha-1 = "0.9" url = "2.1.0" utf-8 = "0.7.5" +thiserror = "1.0.23" [dependencies.native-tls] optional = true diff --git a/README.md b/README.md index 83b276a..7173582 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ fn main () { let mut websocket = accept(stream.unwrap()).unwrap(); loop { let msg = websocket.read_message().unwrap(); - + // We do not want to send back ping/pong messages. if msg.is_binary() || msg.is_text() { websocket.write_message(msg).unwrap(); @@ -62,7 +62,7 @@ Testing ------- Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for -WebSockets. It is also covered by internal unit tests as good as possible. +WebSockets. It is also covered by internal unit tests as well as possible. Contributing ------------ diff --git a/src/client.rs b/src/client.rs index 1741fa2..5ed89cf 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,7 +52,7 @@ mod encryption { use std::net::TcpStream; use crate::{ - error::{Error, Result}, + error::{Error, Result, UrlError}, stream::Mode, }; @@ -62,7 +62,7 @@ mod encryption { pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { match mode { Mode::Plain => Ok(stream), - Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())), + Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)), } } } @@ -71,7 +71,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::{ - error::{Error, Result}, + error::{Error, Result, UrlError}, handshake::{client::ClientHandshake, HandshakeError}, protocol::WebSocket, stream::{Mode, NoDelay}, @@ -103,8 +103,7 @@ pub fn connect_with_config( ) -> Result<(WebSocket, Response)> { let uri = request.uri(); let mode = uri_mode(uri)?; - let host = - request.uri().host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?; let port = uri.port_u16().unwrap_or(match mode { Mode::Plain => 80, Mode::Tls => 443, @@ -166,7 +165,7 @@ pub fn connect(request: Req) -> Result<(WebSocket Result { - let domain = uri.host().ok_or_else(|| Error::Url("No host name in the URL".into()))?; + let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?; for addr in addrs { debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { @@ -175,7 +174,7 @@ fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result Result { match uri.scheme_str() { Some("ws") => Ok(Mode::Plain), Some("wss") => Ok(Mode::Tls), - _ => Err(Error::Url("URL scheme not supported".into())), + _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)), } } diff --git a/src/error.rs b/src/error.rs index c2becc7..f4dfdf1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,10 @@ //! Error handling. -use std::{borrow::Cow, error::Error as ErrorTrait, fmt, io, result, str, string}; +use std::{io, result, str, string}; -use crate::protocol::Message; +use crate::protocol::{frame::coding::Data, Message}; use http::Response; +use thiserror::Error; #[cfg(feature = "tls")] pub mod tls { @@ -14,8 +15,8 @@ pub mod tls { /// Result type of all Tungstenite library calls. pub type Result = result::Result; -/// Possible WebSocket errors -#[derive(Debug)] +/// Possible WebSocket errors. +#[derive(Error, Debug)] pub enum Error { /// WebSocket connection closed normally. This informs you of the close. /// It's not an error as such and nothing wrong happened. @@ -28,6 +29,7 @@ pub enum Error { /// /// Receiving this error means that the WebSocket object is not usable anymore and the /// only meaningful action with it is dropping it. + #[error("Connection closed normally")] ConnectionClosed, /// Trying to work with already closed connection. /// @@ -36,56 +38,39 @@ pub enum Error { /// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the /// connection when it really shouldn't anymore, so this really indicates a programmer /// error on your part. + #[error("Trying to work with closed connection")] AlreadyClosed, /// Input-output error. Apart from WouldBlock, these are generally errors with the /// underlying connection and you should probably consider them fatal. - Io(io::Error), + #[error("IO error: {0}")] + Io(#[from] io::Error), + /// TLS error. #[cfg(feature = "tls")] - /// TLS error - Tls(tls::Error), + #[error("TLS error: {0}")] + Tls(#[from] tls::Error), /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). - Capacity(Cow<'static, str>), + #[error("Space limit exceeded: {0}")] + Capacity(CapacityError), /// Protocol violation. - Protocol(Cow<'static, str>), + #[error("WebSocket protocol error: {0}")] + Protocol(ProtocolError), /// Message send queue full. + #[error("Send queue is full")] SendQueueFull(Message), - /// UTF coding error + /// UTF coding error. + #[error("UTF-8 encoding error")] Utf8, /// Invalid URL. - Url(Cow<'static, str>), + #[error("URL error: {0}")] + Url(UrlError), /// HTTP error. + #[error("HTTP error: {}", .0.status())] Http(Response>), /// HTTP format error. - HttpFormat(http::Error), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::ConnectionClosed => write!(f, "Connection closed normally"), - Error::AlreadyClosed => write!(f, "Trying to work with closed connection"), - Error::Io(ref err) => write!(f, "IO error: {}", err), - #[cfg(feature = "tls")] - Error::Tls(ref err) => write!(f, "TLS error: {}", err), - Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg), - Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg), - Error::SendQueueFull(_) => write!(f, "Send queue is full"), - Error::Utf8 => write!(f, "UTF-8 encoding error"), - Error::Url(ref msg) => write!(f, "URL error: {}", msg), - Error::Http(ref code) => write!(f, "HTTP error: {}", code.status()), - Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err), - } - } -} - -impl ErrorTrait for Error {} - -impl From for Error { - fn from(err: io::Error) -> Self { - Error::Io(err) - } + #[error("HTTP format error: {0}")] + HttpFormat(#[from] http::Error), } impl From for Error { @@ -130,24 +115,136 @@ impl From for Error { } } -impl From for Error { - fn from(err: http::Error) -> Self { - Error::HttpFormat(err) - } -} - -#[cfg(feature = "tls")] -impl From for Error { - fn from(err: tls::Error) -> Self { - Error::Tls(err) - } -} - impl From for Error { fn from(err: httparse::Error) -> Self { match err { - httparse::Error::TooManyHeaders => Error::Capacity("Too many headers".into()), - e => Error::Protocol(e.to_string().into()), + httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders), + e => Error::Protocol(ProtocolError::HttparseError(e)), } } } + +/// Indicates the specific type/cause of a capacity error. +#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +pub enum CapacityError { + /// Too many headers provided (see [`httparse::Error::TooManyHeaders`]). + #[error("Too many headers")] + TooManyHeaders, + /// Received header is too long. + #[error("Header too long")] + HeaderTooLong, + /// Message is bigger than the maximum allowed size. + #[error("Message too long: {size} > {max_size}")] + MessageTooLong { + /// The size of the message. + size: usize, + /// The maximum allowed message size. + max_size: usize, + }, + /// TCP buffer is full. + #[error("Incoming TCP buffer is full")] + TcpBufferFull, +} + +/// Indicates the specific type/cause of a protocol error. +#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +pub enum ProtocolError { + /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). + #[error("Unsupported HTTP method used - only GET is allowed")] + WrongHttpMethod, + /// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher). + #[error("HTTP version must be 1.1 or higher")] + WrongHttpVersion, + /// Missing `Connection: upgrade` HTTP header. + #[error("No \"Connection: upgrade\" header")] + MissingConnectionUpgradeHeader, + /// Missing `Upgrade: websocket` HTTP header. + #[error("No \"Upgrade: websocket\" header")] + MissingUpgradeWebSocketHeader, + /// Missing `Sec-WebSocket-Version: 13` HTTP header. + #[error("No \"Sec-WebSocket-Version: 13\" header")] + MissingSecWebSocketVersionHeader, + /// Missing `Sec-WebSocket-Key` HTTP header. + #[error("No \"Sec-WebSocket-Key\" header")] + MissingSecWebSocketKey, + /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. + #[error("Key mismatch in \"Sec-WebSocket-Accept\" header")] + SecWebSocketAcceptKeyMismatch, + /// Garbage data encountered after client request. + #[error("Junk after client request")] + JunkAfterRequest, + /// Custom responses must be unsuccessful. + #[error("Custom response must not be successful")] + CustomResponseSuccessful, + /// No more data while still performing handshake. + #[error("Handshake not finished")] + HandshakeIncomplete, + /// Wrapper around a [`httparse::Error`] value. + #[error("httparse error: {0}")] + HttparseError(#[from] httparse::Error), + /// Not allowed to send after having sent a closing frame. + #[error("Sending after closing is not allowed")] + SendAfterClosing, + /// Remote sent data after sending a closing frame. + #[error("Remote sent after having closed")] + ReceivedAfterClosing, + /// Reserved bits in frame header are non-zero. + #[error("Reserved bits are non-zero")] + NonZeroReservedBits, + /// The server must close the connection when an unmasked frame is received. + #[error("Received an unmasked frame from client")] + UnmaskedFrameFromClient, + /// The client must close the connection when a masked frame is received. + #[error("Received a masked frame from server")] + MaskedFrameFromServer, + /// Control frames must not be fragmented. + #[error("Fragmented control frame")] + FragmentedControlFrame, + /// Control frames must have a payload of 125 bytes or less. + #[error("Control frame too big (payload must be 125 bytes or less)")] + ControlFrameTooBig, + /// Type of control frame not recognised. + #[error("Unknown control frame type: {0}")] + UnknownControlFrameType(u8), + /// Type of data frame not recognised. + #[error("Unknown data frame type: {0}")] + UnknownDataFrameType(u8), + /// Received a continue frame despite there being nothing to continue. + #[error("Continue frame but nothing to continue")] + UnexpectedContinueFrame, + /// Received data while waiting for more fragments. + #[error("While waiting for more fragments received: {0}")] + ExpectedFragment(Data), + /// Connection closed without performing the closing handshake. + #[error("Connection reset without closing handshake")] + ResetWithoutClosingHandshake, + /// Encountered an invalid opcode. + #[error("Encountered invalid opcode: {0}")] + InvalidOpcode(u8), + /// The payload for the closing frame is invalid. + #[error("Invalid close sequence")] + InvalidCloseSequence, +} + +/// Indicates the specific type/cause of URL error. +#[derive(Error, Debug, PartialEq, Eq)] +pub enum UrlError { + /// TLS is used despite not being compiled with the TLS feature enabled. + #[error("TLS support not compiled in")] + TlsFeatureNotEnabled, + /// The URL does not include a host name. + #[error("No host name in the URL")] + NoHostName, + /// Failed to connect with this URL. + #[error("Unable to connect to {0}")] + UnableToConnect(String), + /// Unsupported URL scheme used (only `ws://` or `wss://` may be used). + #[error("URL scheme not supported")] + UnsupportedUrlScheme, + /// The URL host name, though included, is empty. + #[error("URL contains empty host name")] + EmptyHostName, + /// The URL does not include a path/query. + #[error("No path/query in URL")] + NoPathOrQuery, +} diff --git a/src/handshake/client.rs b/src/handshake/client.rs index ea011fd..92e5477 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -16,7 +16,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolError, Result, UrlError}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -42,11 +42,11 @@ impl ClientHandshake { config: Option, ) -> Result> { if request.method() != http::Method::GET { - return Err(Error::Protocol("Invalid HTTP method, only GET supported".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } // Check the URI scheme: only ws or wss are supported @@ -97,8 +97,7 @@ fn generate_request(request: Request, key: &str) -> Result> { let mut req = Vec::new(); let uri = request.uri(); - let authority = - uri.authority().ok_or_else(|| Error::Url("No host name in the URL".into()))?.as_str(); + let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str(); let host = if let Some(idx) = authority.find('@') { // handle possible name:password@ authority.split_at(idx + 1).1 @@ -106,7 +105,7 @@ fn generate_request(request: Request, key: &str) -> Result> { authority }; if authority.is_empty() { - return Err(Error::Url("URL contains empty host name".into())); + return Err(Error::Url(UrlError::EmptyHostName)); } write!( @@ -120,8 +119,7 @@ fn generate_request(request: Request, key: &str) -> Result> { Sec-WebSocket-Key: {key}\r\n", version = request.version(), host = host, - path = - uri.path_and_query().ok_or_else(|| Error::Url("No path/query in URL".into()))?.as_str(), + path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(), key = key ) .unwrap(); @@ -165,7 +163,7 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into())); + return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); } // 3. If the response lacks a |Connection| header field or the // |Connection| header field doesn't contain a token that is an @@ -177,14 +175,14 @@ impl VerifyData { .map(|h| h.eq_ignore_ascii_case("Upgrade")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into())); + return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); } // 4. If the response lacks a |Sec-WebSocket-Accept| header field or // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) if !headers.get("Sec-WebSocket-Accept").map(|h| h == &self.accept_key).unwrap_or(false) { - return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into())); + return Err(Error::Protocol(ProtocolError::SecWebSocketAcceptKeyMismatch)); } // 5. If the response includes a |Sec-WebSocket-Extensions| header // field and this header field indicates the use of an extension @@ -218,7 +216,7 @@ impl TryParse for Response { impl<'h, 'b: 'h> FromHttparse> for Response { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } let headers = HeaderMap::from_httparse(raw.headers)?; diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs index 5c7e000..ced0153 100644 --- a/src/handshake/machine.rs +++ b/src/handshake/machine.rs @@ -3,7 +3,7 @@ use log::*; use std::io::{Cursor, Read, Write}; use crate::{ - error::{Error, Result}, + error::{CapacityError, Error, ProtocolError, Result}, util::NonBlockingResult, }; use input_buffer::{InputBuffer, MIN_READ}; @@ -46,11 +46,11 @@ impl HandshakeMachine { let read = buf .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) // TODO limit size - .map_err(|_| Error::Capacity("Header too long".into()))? + .map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))? .read_from(&mut self.stream) .no_block()?; match read { - Some(0) => Err(Error::Protocol("Handshake not finished".into())), + Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { buf.advance(size); RoundResult::StageFinished(StageResult::DoneReading { diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 53227ab..f80c11b 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -19,7 +19,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolError, Result}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -34,11 +34,11 @@ pub type ErrorResponse = HttpResponse>; fn create_parts(request: &HttpRequest) -> Result { if request.method() != http::Method::GET { - return Err(Error::Protocol("Method is not GET".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if request.version() < http::Version::HTTP_11 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } if !request @@ -48,7 +48,7 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade"))) .unwrap_or(false) { - return Err(Error::Protocol("No \"Connection: upgrade\" in client request".into())); + return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader)); } if !request @@ -58,17 +58,17 @@ fn create_parts(request: &HttpRequest) -> Result { .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - return Err(Error::Protocol("No \"Upgrade: websocket\" in client request".into())); + return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)); } if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) { - return Err(Error::Protocol("No \"Sec-WebSocket-Version: 13\" in client request".into())); + return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader)); } let key = request .headers() .get("Sec-WebSocket-Key") - .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; + .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?; let builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) @@ -125,11 +125,11 @@ impl TryParse for Request { impl<'h, 'b: 'h> FromHttparse> for Request { fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result { if raw.method.expect("Bug: no method in header") != "GET" { - return Err(Error::Protocol("Method is not GET".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpMethod)); } if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { - return Err(Error::Protocol("HTTP version should be 1.1 or higher".into())); + return Err(Error::Protocol(ProtocolError::WrongHttpVersion)); } let headers = HeaderMap::from_httparse(raw.headers)?; @@ -237,7 +237,7 @@ impl HandshakeRole for ServerHandshake { Ok(match finish { StageResult::DoneReading { stream, result, tail } => { if !tail.is_empty() { - return Err(Error::Protocol("Junk after client request".into())); + return Err(Error::Protocol(ProtocolError::JunkAfterRequest)); } let response = create_response(&result)?; @@ -256,9 +256,7 @@ impl HandshakeRole for ServerHandshake { Err(resp) => { if resp.status().is_success() { - return Err(Error::Protocol( - "Custom response must not be successful".into(), - )); + return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful)); } self.error_response = Some(resp); diff --git a/src/protocol/frame/coding.rs b/src/protocol/frame/coding.rs index e726161..a37dcd2 100644 --- a/src/protocol/frame/coding.rs +++ b/src/protocol/frame/coding.rs @@ -143,7 +143,7 @@ pub enum CloseCode { Abnormal, /// Indicates that an endpoint is terminating the connection /// because it has received data within a message that was not - /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] /// data within a text message). Invalid, /// Indicates that an endpoint is terminating the connection diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index ff64fa2..986bba0 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -13,7 +13,7 @@ use super::{ coding::{CloseCode, Control, Data, OpCode}, mask::{apply_mask, generate_mask}, }; -use crate::error::{Error, Result}; +use crate::error::{Error, ProtocolError, Result}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -186,9 +186,7 @@ impl FrameHeader { // Disallow bad opcode match opcode { OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { - return Err(Error::Protocol( - format!("Encountered invalid opcode: {}", first & 0x0F).into(), - )) + return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))) } _ => (), } @@ -286,7 +284,7 @@ impl Frame { pub(crate) fn into_close(self) -> Result>> { match self.payload.len() { 0 => Ok(None), - 1 => Err(Error::Protocol("Invalid close sequence".into())), + 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), _ => { let mut data = self.payload; let code = NetworkEndian::read_u16(&data[0..2]).into(); diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index dfd0bd5..1e41853 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -8,7 +8,7 @@ mod mask; pub use self::frame::{CloseFrame, Frame, FrameHeader}; -use crate::error::{Error, Result}; +use crate::error::{CapacityError, Error, Result}; use input_buffer::{InputBuffer, MIN_READ}; use log::*; use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; @@ -133,9 +133,10 @@ impl FrameCodec { // Enforce frame size limit early and make sure `length` // is not too big (fits into `usize`). if length > max_size as u64 { - return Err(Error::Capacity( - format!("Message length too big: {} > {}", length, max_size).into(), - )); + return Err(Error::Capacity(CapacityError::MessageTooLong { + size: length as usize, + max_size, + })); } let input_size = cursor.get_ref().len() as u64 - cursor.position(); @@ -155,7 +156,7 @@ impl FrameCodec { .in_buffer .prepare_reserve(MIN_READ) .with_limit(usize::max_value()) - .map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))? + .map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))? .read_from(stream)?; if size == 0 { trace!("no frame received"); @@ -206,6 +207,8 @@ impl FrameCodec { #[cfg(test)] mod tests { + use crate::error::{CapacityError, Error}; + use super::{Frame, FrameSocket}; use std::io::Cursor; @@ -266,9 +269,9 @@ mod tests { fn size_limit_hit() { let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::new(raw); - assert_eq!( - sock.read_frame(Some(5)).unwrap_err().to_string(), - "Space limit exceeded: Message length too big: 7 > 5" - ); + assert!(matches!( + sock.read_frame(Some(5)), + Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 })) + )); } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index f799dbf..6720c3c 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -6,7 +6,7 @@ use std::{ }; use super::frame::CloseFrame; -use crate::error::{Error, Result}; +use crate::error::{CapacityError, Error, Result}; mod string_collect { use utf8::DecodeError; @@ -122,9 +122,10 @@ impl IncompleteMessage { let portion_size = tail.as_ref().len(); // Be careful about integer overflows here. if my_size > max_size || portion_size > max_size - my_size { - return Err(Error::Capacity( - format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into(), - )); + return Err(Error::Capacity(CapacityError::MessageTooLong { + size: my_size + portion_size, + max_size, + })); } match self.collector { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 63763f0..215b061 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,7 +21,7 @@ use self::{ message::{IncompleteMessage, IncompleteMessageType}, }; use crate::{ - error::{Error, Result}, + error::{Error, ProtocolError, Result}, util::NonBlockingResult, }; @@ -331,7 +331,7 @@ impl WebSocketContext { // Do not write after sending a close frame. if !self.state.is_active() { - return Err(Error::Protocol("Sending after closing is not allowed".into())); + return Err(Error::Protocol(ProtocolError::SendAfterClosing)); } if let Some(max_send_queue) = self.config.max_send_queue { @@ -431,9 +431,7 @@ impl WebSocketContext { .check_connection_reset(self.state)? { if !self.state.can_read() { - return Err(Error::Protocol( - "Remote sent frame after having sent a Close Frame".into(), - )); + return Err(Error::Protocol(ProtocolError::ReceivedAfterClosing)); } // MUST be 0 unless an extension is negotiated that defines meanings // for non-zero values. If a nonzero value is received and none of @@ -443,7 +441,7 @@ impl WebSocketContext { { let hdr = frame.header(); if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 { - return Err(Error::Protocol("Reserved bits are non-zero".into())); + return Err(Error::Protocol(ProtocolError::NonZeroReservedBits)); } } @@ -458,15 +456,13 @@ impl WebSocketContext { // frame that is not masked. (RFC 6455) // The only exception here is if the user explicitly accepts given // stream by setting WebSocketConfig.accept_unmasked_frames to true - return Err(Error::Protocol( - "Received an unmasked frame from client".into(), - )); + return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient)); } } Role::Client => { if frame.is_masked() { // A client MUST close a connection if it detects a masked frame. (RFC 6455) - return Err(Error::Protocol("Received a masked frame from server".into())); + return Err(Error::Protocol(ProtocolError::MaskedFrameFromServer)); } } } @@ -477,14 +473,14 @@ impl WebSocketContext { // All control frames MUST have a payload length of 125 bytes or less // and MUST NOT be fragmented. (RFC 6455) _ if !frame.header().is_final => { - Err(Error::Protocol("Fragmented control frame".into())) + Err(Error::Protocol(ProtocolError::FragmentedControlFrame)) } _ if frame.payload().len() > 125 => { - Err(Error::Protocol("Control frame too big".into())) + Err(Error::Protocol(ProtocolError::ControlFrameTooBig)) } OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)), OpCtl::Reserved(i) => { - Err(Error::Protocol(format!("Unknown control frame type {}", i).into())) + Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { let data = frame.into_data(); @@ -506,7 +502,7 @@ impl WebSocketContext { msg.extend(frame.into_data(), self.config.max_message_size)?; } else { return Err(Error::Protocol( - "Continue frame but nothing to continue".into(), + ProtocolError::UnexpectedContinueFrame, )); } if fin { @@ -515,9 +511,9 @@ impl WebSocketContext { Ok(None) } } - c if self.incomplete.is_some() => Err(Error::Protocol( - format!("Received {} while waiting for more fragments", c).into(), - )), + c if self.incomplete.is_some() => { + Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) + } OpData::Text | OpData::Binary => { let msg = { let message_type = match data { @@ -537,7 +533,7 @@ impl WebSocketContext { } } OpData::Reserved(i) => { - Err(Error::Protocol(format!("Unknown data frame type {}", i).into())) + Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i))) } } } @@ -548,7 +544,7 @@ impl WebSocketContext { WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => { Err(Error::ConnectionClosed) } - _ => Err(Error::Protocol("Connection reset without closing handshake".into())), + _ => Err(Error::Protocol(ProtocolError::ResetWithoutClosingHandshake)), } } } @@ -673,6 +669,7 @@ impl CheckConnectionReset for Result { #[cfg(test)] mod tests { use super::{Message, Role, WebSocket, WebSocketConfig}; + use crate::error::{CapacityError, Error}; use std::{io, io::Cursor}; @@ -715,10 +712,11 @@ mod tests { ]); let limit = WebSocketConfig { max_message_size: Some(10), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - assert_eq!( - socket.read_message().unwrap_err().to_string(), - "Space limit exceeded: Message too big: 7 + 6 > 10" - ); + + assert!(matches!( + socket.read_message(), + Err(Error::Capacity(CapacityError::MessageTooLong { size: 13, max_size: 10 })) + )); } #[test] @@ -726,9 +724,10 @@ mod tests { let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]); let limit = WebSocketConfig { max_message_size: Some(2), ..WebSocketConfig::default() }; let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit)); - assert_eq!( - socket.read_message().unwrap_err().to_string(), - "Space limit exceeded: Message too big: 0 + 3 > 2" - ); + + assert!(matches!( + socket.read_message(), + Err(Error::Capacity(CapacityError::MessageTooLong { size: 3, max_size: 2 })) + )); } } diff --git a/tests/no_send_after_close.rs b/tests/no_send_after_close.rs index f348eca..a9dcab2 100644 --- a/tests/no_send_after_close.rs +++ b/tests/no_send_after_close.rs @@ -8,7 +8,7 @@ use std::{ time::Duration, }; -use tungstenite::{accept, connect, Error, Message}; +use tungstenite::{accept, connect, error::ProtocolError, Error, Message}; use url::Url; #[test] @@ -46,7 +46,7 @@ fn test_no_send_after_close() { assert!(err.is_err()); match err.unwrap_err() { - Error::Protocol(s) => assert_eq!("Sending after closing is not allowed", s), + Error::Protocol(s) => assert_eq!(s, ProtocolError::SendAfterClosing), e => panic!("unexpected error: {:?}", e), }