diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 3fce6712a..a8c769a7d 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -150,7 +150,7 @@ impl AsyncClient { where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); + let subscribe = Subscribe::new_many(topics)?; let request = Request::Subscribe(subscribe); self.request_tx.send(request).await?; Ok(()) @@ -161,7 +161,7 @@ impl AsyncClient { where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); + let subscribe = Subscribe::new_many(topics)?; let request = Request::Subscribe(subscribe); self.request_tx.try_send(request)?; Ok(()) diff --git a/rumqttc/src/mqttbytes/mod.rs b/rumqttc/src/mqttbytes/mod.rs index 55251384e..ba5034586 100644 --- a/rumqttc/src/mqttbytes/mod.rs +++ b/rumqttc/src/mqttbytes/mod.rs @@ -53,6 +53,8 @@ pub enum Error { MalformedPacket, #[error("Malformed remaining length")] MalformedRemainingLength, + #[error("A Subscribe packet must contain atleast one filter")] + EmptySubscription, /// More bytes required to frame packet. Argument /// implies minimum additional bytes required to /// proceed further diff --git a/rumqttc/src/mqttbytes/v4/subscribe.rs b/rumqttc/src/mqttbytes/v4/subscribe.rs index 05380cea8..9d593a2a1 100644 --- a/rumqttc/src/mqttbytes/v4/subscribe.rs +++ b/rumqttc/src/mqttbytes/v4/subscribe.rs @@ -21,20 +21,15 @@ impl Subscribe { } } - pub fn new_many(topics: T) -> Subscribe + pub fn new_many(topics: T) -> Result where T: IntoIterator, { - Subscribe { - pkid: 0, - filters: topics.into_iter().collect(), - } - } + let filters: Vec = topics.into_iter().collect(); - pub fn empty_subscribe() -> Subscribe { - Subscribe { - pkid: 0, - filters: Vec::new(), + match filters.len() { + 0 => Err(Error::EmptySubscription), + _ => Ok(Subscribe { pkid: 0, filters }), } } @@ -70,9 +65,10 @@ impl Subscribe { }); } - let subscribe = Subscribe { pkid, filters }; - - Ok(subscribe) + match filters.len() { + 0 => Err(Error::EmptySubscription), + _ => Ok(Subscribe { pkid, filters }), + } } pub fn write(&self, buffer: &mut BytesMut) -> Result { diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index cc9346c1c..7039c3bfb 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -29,6 +29,8 @@ pub enum StateError { WrongPacket, #[error("Timeout while waiting to resolve collision")] CollisionTimeout, + #[error("A Subscribe packet must contain atleast one filter")] + EmptySubscription, #[error("Mqtt serialization/deserialization error: {0}")] Deserialization(#[from] mqttbytes::Error), } @@ -412,6 +414,10 @@ impl MqttState { } fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + if subscription.filters.is_empty() { + return Err(StateError::EmptySubscription); + } + let pkid = self.next_pkid(); subscription.pkid = pkid; diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 6c6e498bd..bd3d51ac8 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -189,7 +189,7 @@ impl AsyncClient { where T: IntoIterator, { - let mut subscribe = Subscribe::new_many(topics); + let mut subscribe = Subscribe::new_many(topics)?; let pkid = { let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { @@ -209,7 +209,7 @@ impl AsyncClient { where T: IntoIterator, { - let mut subscribe = Subscribe::new_many(topics); + let mut subscribe = Subscribe::new_many(topics)?; let pkid = { let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index f20efe87e..965c90260 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -21,7 +21,7 @@ pub enum ClientError { #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] RequestsFull, #[error("Serialization error")] - Mqtt5(Error), + Mqtt5(#[from] Error), } fn get_ack_req(qos: QoS, pkid: u16) -> Option { diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index 420b452aa..cf9761611 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -119,7 +119,7 @@ impl Client { where T: IntoIterator, { - let mut subscribe = Subscribe::new_many(topics); + let mut subscribe = Subscribe::new_many(topics)?; let pkid = { let mut request_buf = self.client.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { diff --git a/rumqttc/src/v5/packet/mod.rs b/rumqttc/src/v5/packet/mod.rs index 8f954c1cf..4badb3a69 100644 --- a/rumqttc/src/v5/packet/mod.rs +++ b/rumqttc/src/v5/packet/mod.rs @@ -1,7 +1,4 @@ -use std::{ - fmt::{self, Display, Formatter}, - slice::Iter, -}; +use std::slice::Iter; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -105,33 +102,54 @@ enum PropertyType { } /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] pub enum Error { + #[error("Expected Connect, received: {0:?}")] NotConnect(PacketType), + #[error("Unexpected Connect")] UnexpectedConnect, + #[error("Invalid Connect return code: {0}")] InvalidConnectReturnCode(u8), - InvalidReason(u8), + #[error("Invalid protocol")] InvalidProtocol, + #[error("Invalid protocol level: {0}")] InvalidProtocolLevel(u8), + #[error("Incorrect packet format")] IncorrectPacketFormat, + #[error("Invalid packet type: {0}")] InvalidPacketType(u8), + #[error("Invalid property type: {0}")] InvalidPropertyType(u8), - InvalidRetainForwardRule(u8), + #[error("Invalid QoS level: {0}")] InvalidQoS(u8), + #[error("Invalid retain forward rule: {0}")] + InvalidRetainForwardRule(u8), + #[error("Invalid subscribe reason code: {0}")] InvalidSubscribeReasonCode(u8), + #[error("Packet id Zero")] PacketIdZero, - SubscriptionIdZero, + #[error("Payload size is incorrect")] PayloadSizeIncorrect, + #[error("payload is too long")] PayloadTooLong, + #[error("payload size limit exceeded: {0}")] PayloadSizeLimitExceeded(usize), + #[error("Payload required")] PayloadRequired, + #[error("Topic is not UTF-8")] TopicNotUtf8, + #[error("Promised boundary crossed: {0}")] BoundaryCrossed(usize), + #[error("Malformed packet")] MalformedPacket, + #[error("Malformed remaining length")] MalformedRemainingLength, + #[error("A Subscribe packet must contain atleast one filter")] + EmptySubscription, /// More bytes required to frame packet. Argument /// implies minimum additional bytes required to /// proceed further + #[error("At least {0} more bytes required to frame packet")] InsufficientBytes(usize), } @@ -481,9 +499,3 @@ fn read_u32(stream: &mut Bytes) -> Result { Ok(stream.get_u32()) } - -impl Display for Error { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Error = {:?}", self) - } -} diff --git a/rumqttc/src/v5/packet/subscribe.rs b/rumqttc/src/v5/packet/subscribe.rs index 9a5c43d1e..1941df6ab 100644 --- a/rumqttc/src/v5/packet/subscribe.rs +++ b/rumqttc/src/v5/packet/subscribe.rs @@ -27,22 +27,19 @@ impl Subscribe { } } - pub fn new_many(topics: T) -> Subscribe + pub fn new_many(topics: T) -> Result where T: IntoIterator, { - Subscribe { - pkid: 0, - filters: topics.into_iter().collect(), - properties: None, - } - } - - pub fn empty_subscribe() -> Subscribe { - Subscribe { - pkid: 0, - filters: Vec::new(), - properties: None, + let filters: Vec = topics.into_iter().collect(); + + match filters.len() { + 0 => Err(Error::EmptySubscription), + _ => Ok(Subscribe { + pkid: 0, + filters, + properties: None, + }), } } @@ -112,13 +109,14 @@ impl Subscribe { }); } - let subscribe = Subscribe { - pkid, - filters, - properties, - }; - - Ok(subscribe) + match filters.len() { + 0 => Err(Error::EmptySubscription), + _ => Ok(Subscribe { + pkid, + filters, + properties, + }), + } } pub fn write(&self, buffer: &mut BytesMut) -> Result {