From 0597aac5d61690ed25df874b1cc39fbcf460001a Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 17:24:32 +0530 Subject: [PATCH 01/41] rumqttc: move old code to v4 module Signed-off-by: Abhik Jain --- Cargo.lock | 109 ++-- benchmarks/clients/mesh.rs | 4 +- benchmarks/clients/rumqttasync.rs | 2 +- benchmarks/clients/rumqttasyncqos0.rs | 2 +- benchmarks/clients/rumqttsync.rs | 2 +- rumqttc/examples/async_manual_acks.rs | 2 +- rumqttc/examples/asyncpubsub.rs | 2 +- rumqttc/examples/syncpubsub.rs | 2 +- rumqttc/examples/tls.rs | 2 +- rumqttc/examples/tls2.rs | 2 +- rumqttc/src/lib.rs | 708 +------------------------- rumqttc/src/{ => v4}/client.rs | 2 +- rumqttc/src/{ => v4}/eventloop.rs | 6 +- rumqttc/src/{ => v4}/framed.rs | 2 +- rumqttc/src/v4/mod.rs | 706 +++++++++++++++++++++++++ rumqttc/src/{ => v4}/state.rs | 4 +- rumqttc/src/{ => v4}/tls.rs | 2 +- rumqttc/src/v5/mod.rs | 0 rumqttc/tests/broker.rs | 2 +- rumqttc/tests/reliability.rs | 2 +- 20 files changed, 807 insertions(+), 756 deletions(-) rename rumqttc/src/{ => v4}/client.rs (99%) rename rumqttc/src/{ => v4}/eventloop.rs (98%) rename rumqttc/src/{ => v4}/framed.rs (98%) create mode 100644 rumqttc/src/v4/mod.rs rename rumqttc/src/{ => v4}/state.rs (99%) rename rumqttc/src/{ => v4}/tls.rs (98%) create mode 100644 rumqttc/src/v5/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 8dbaf5cef..789effd4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,18 +103,18 @@ dependencies = [ [[package]] name = "async-tungstenite" -version = "0.13.1" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07b30ef0ea5c20caaa54baea49514a206308989c68be7ecd86c7f956e4da6378" +checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" dependencies = [ "futures-io", "futures-util", "log", "pin-project-lite 0.2.7", + "rustls-native-certs", "tokio 1.8.2", - "tokio-rustls", - "tungstenite 0.13.0", - "webpki-roots", + "tokio-rustls 0.23.2", + "tungstenite 0.16.0", ] [[package]] @@ -2041,14 +2041,15 @@ dependencies = [ "mqttbytes", "pollster", "pretty_env_logger", - "rustls", + "rustls 0.20.2", "rustls-native-certs", + "rustls-pemfile 0.3.0", "serde", "thiserror", "tokio 1.8.2", - "tokio-rustls", + "tokio-rustls 0.23.2", "url", - "webpki", + "webpki 0.22.0", "ws_stream_tungstenite", ] @@ -2070,7 +2071,7 @@ dependencies = [ "thiserror", "tokio 1.8.2", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.22.0", "warp", ] @@ -2161,22 +2162,52 @@ dependencies = [ "base64 0.13.0", "log", "ring", - "sct", - "webpki", + "sct 0.6.1", + "webpki 0.21.4", +] + +[[package]] +name = "rustls" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d37e5e2290f3e040b594b1a9e04377c2c671f1a1cfd9bfdef82106ac1c113f84" +dependencies = [ + "log", + "ring", + "sct 0.7.0", + "webpki 0.22.0", ] [[package]] name = "rustls-native-certs" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a07b7c1885bd8ed3831c289b7870b13ef46fe0e856d288c30d9cc17d75a2092" +checksum = "5ca9ebdfa27d3fc180e42879037b5338ab1c040c06affd00d8338598e7800943" dependencies = [ "openssl-probe", - "rustls", + "rustls-pemfile 0.2.1", "schannel", "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" +dependencies = [ + "base64 0.13.0", +] + +[[package]] +name = "rustls-pemfile" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee86d63972a7c661d1536fefe8c3c8407321c3df668891286de28abcd087360" +dependencies = [ + "base64 0.13.0", +] + [[package]] name = "ryu" version = "1.0.5" @@ -2230,6 +2261,16 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.3.1" @@ -2627,9 +2668,20 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" dependencies = [ - "rustls", + "rustls 0.19.1", + "tokio 1.8.2", + "webpki 0.21.4", +] + +[[package]] +name = "tokio-rustls" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a27d5f2b839802bd8267fa19b0530f5a08b9c08cd417976be2a65d130fe1c11b" +dependencies = [ + "rustls 0.20.2", "tokio 1.8.2", - "webpki", + "webpki 0.22.0", ] [[package]] @@ -2733,25 +2785,23 @@ dependencies = [ [[package]] name = "tungstenite" -version = "0.13.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" +checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" dependencies = [ "base64 0.13.0", "byteorder", "bytes 1.0.1", "http", "httparse", - "input_buffer", "log", "rand 0.8.4", - "rustls", + "rustls 0.20.2", "sha-1", "thiserror", "url", "utf-8", - "webpki", - "webpki-roots", + "webpki 0.22.0", ] [[package]] @@ -2999,12 +3049,13 @@ dependencies = [ ] [[package]] -name = "webpki-roots" -version = "0.21.1" +name = "webpki" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" dependencies = [ - "webpki", + "ring", + "untrusted", ] [[package]] @@ -3072,9 +3123,9 @@ dependencies = [ [[package]] name = "ws_stream_tungstenite" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34c786fc3d0a792f8a6e7a69f3b85afa1cf7b2560bbd434d7d5c32a580e153c0" +checksum = "a672ec78525bf189cefa7f1b72c55f928b3edbdb967e680ca49748ab20821045" dependencies = [ "async-tungstenite", "async_io_stream", @@ -3087,5 +3138,5 @@ dependencies = [ "pharos", "rustc_version 0.4.0", "tokio 1.8.2", - "tungstenite 0.13.0", + "tungstenite 0.16.0", ] diff --git a/benchmarks/clients/mesh.rs b/benchmarks/clients/mesh.rs index 8f6d04eec..96f38de50 100644 --- a/benchmarks/clients/mesh.rs +++ b/benchmarks/clients/mesh.rs @@ -3,7 +3,7 @@ use tokio::task; use std::thread; use std::path::PathBuf; use bytes::Bytes; -use rumqttlog::router::{Data}; +use rumqttlog::router::Data; mod common; @@ -74,5 +74,3 @@ async fn read(tx: Sender<(usize, RouterInMessage)>) { println!("Id = {}, Total size = {}", id, total_size); } - - diff --git a/benchmarks/clients/rumqttasync.rs b/benchmarks/clients/rumqttasync.rs index 3bfb49ee4..6e4844ad8 100644 --- a/benchmarks/clients/rumqttasync.rs +++ b/benchmarks/clients/rumqttasync.rs @@ -1,4 +1,4 @@ -use rumqttc::*; +use rumqttc::v4::*; use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttasyncqos0.rs b/benchmarks/clients/rumqttasyncqos0.rs index a2a668b0b..081f19ae0 100644 --- a/benchmarks/clients/rumqttasyncqos0.rs +++ b/benchmarks/clients/rumqttasyncqos0.rs @@ -1,4 +1,4 @@ -use rumqttc::*; +use rumqttc::v4::*; use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttsync.rs b/benchmarks/clients/rumqttsync.rs index da85194dd..fd94bd5df 100644 --- a/benchmarks/clients/rumqttsync.rs +++ b/benchmarks/clients/rumqttsync.rs @@ -1,4 +1,4 @@ -use rumqttc::{self, Client, Event, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{Client, Event, Incoming, MqttOptions, QoS}; use std::error::Error; use std::thread; use std::time::{Duration, Instant}; diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 62f7c20c6..6790991e3 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/asyncpubsub.rs b/rumqttc/examples/asyncpubsub.rs index 4ec2cd983..fef15451d 100644 --- a/rumqttc/examples/asyncpubsub.rs +++ b/rumqttc/examples/asyncpubsub.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::{self, AsyncClient, MqttOptions, QoS}; +use rumqttc::v4::{self, AsyncClient, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index c69ffc69f..b13abc65f 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -1,4 +1,4 @@ -use rumqttc::{self, Client, LastWill, MqttOptions, QoS}; +use rumqttc::v4::{self, Client, LastWill, MqttOptions, QoS}; use std::thread; use std::time::Duration; diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index 2bb1b2272..90bb4180f 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; +use rumqttc::v4::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; use rustls::ClientConfig; use std::error::Error; diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index c54836819..1bf5cf4fb 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; +use rumqttc::v4::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; use std::error::Error; #[tokio::main] diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 69a3c721c..0a30cbefd 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -99,709 +99,5 @@ #[macro_use] extern crate log; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::time::Duration; - -mod client; -mod eventloop; -mod framed; -mod state; -mod tls; - -pub use async_channel::{SendError, Sender, TrySendError}; -pub use client::{AsyncClient, Client, ClientError, Connection}; -pub use eventloop::{ConnectionError, Event, EventLoop}; -pub use mqttbytes::v4::*; -pub use mqttbytes::*; -pub use state::{MqttState, StateError}; -pub use tokio_rustls::rustls::ClientConfig; -pub use tls::Error; - -pub type Incoming = Packet; - -/// Current outgoing activity on the eventloop -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum Outgoing { - /// Publish packet with packet identifier. 0 implies QoS 0 - Publish(u16), - /// Subscribe packet with packet identifier - Subscribe(u16), - /// Unsubscribe packet with packet identifier - Unsubscribe(u16), - /// PubAck packet - PubAck(u16), - /// PubRec packet - PubRec(u16), - /// PubRel packet - PubRel(u16), - /// PubComp packet - PubComp(u16), - /// Ping request packet - PingReq, - /// Ping response packet - PingResp, - /// Disconnect packet - Disconnect, - /// Await for an ack for more outgoing progress - AwaitAck(u16), -} - -/// Requests by the client to mqtt event loop. Request are -/// handled one by one. -#[derive(Clone, Debug, PartialEq)] -pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, -} - -/// Key type for TLS authentication -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum Key { - RSA(Vec), - ECC(Vec), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } -} - -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } -} - -#[derive(Clone)] -pub enum Transport { - Tcp, - Tls(TlsConfiguration), - #[cfg(unix)] - Unix, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - Wss(TlsConfiguration), -} - -impl Default for Transport { - fn default() -> Self { - Self::tcp() - } -} - -impl Transport { - /// Use regular tcp as transport (default) - pub fn tcp() -> Self { - Self::Tcp - } - - /// Use secure tcp with tls as transport - pub fn tls( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - alpn, - client_auth, - }; - - Self::tls_with_config(config) - } - - pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { - Self::Tls(tls_config) - } - - #[cfg(unix)] - pub fn unix() -> Self { - Self::Unix - } - - /// Use websockets as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn ws() -> Self { - Self::Ws - } - - /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn wss( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }; - - Self::wss_with_config(config) - } - - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { - Self::Wss(tls_config) - } -} - -#[derive(Clone)] -pub enum TlsConfiguration { - Simple { - /// connection method - ca: Vec, - /// alpn settings - alpn: Option>>, - /// tls client_authentication - client_auth: Option<(Vec, Key)>, - }, - /// Injected rustls ClientConfig for TLS, to allow more customisation. - Rustls(Arc), -} - -impl From for TlsConfiguration { - fn from(config: ClientConfig) -> Self { - TlsConfiguration::Rustls(Arc::new(config)) - } -} - -// TODO: Should all the options be exposed as public? Drawback -// would be loosing the ability to panic when the user options -// are wrong (e.g empty client id) or aggressive (keep alive time) -/// Options to configure the behaviour of mqtt connection -#[derive(Clone)] -pub struct MqttOptions { - /// broker address that you want to connect to - broker_addr: String, - /// broker port - port: u16, - // What transport protocol to use - transport: Transport, - /// keep alive time to send pingreq to broker when the connection is idle - keep_alive: Duration, - /// clean (or) persistent session - clean_session: bool, - /// client identifier - client_id: String, - /// username and password - credentials: Option<(String, String)>, - /// maximum incoming packet size (verifies remaining length of the packet) - max_incoming_packet_size: usize, - /// Maximum outgoing packet size (only verifies publish payload size) - // TODO Verify this with all packets. This can be packet.write but message left in - // the state might be a footgun as user has to explicitly clean it. Probably state - // has to be moved to network - max_outgoing_packet_size: usize, - /// request (publish, subscribe) channel capacity - request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, - /// Minimum delay time between consecutive outgoing packets - /// while retransmitting pending packets - pending_throttle: Duration, - /// maximum number of outgoing inflight messages - inflight: u16, - /// Last will that will be issued on unexpected disconnect - last_will: Option, - /// Connection timeout - conn_timeout: u64, - /// If set to `true` MQTT acknowledgements are not sent automatically. - /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. - manual_acks: bool, -} - -impl MqttOptions { - /// New mqtt options - pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { - let id = id.into(); - if id.starts_with(' ') || id.is_empty() { - panic!("Invalid client id") - } - - MqttOptions { - broker_addr: host.into(), - port, - transport: Transport::tcp(), - keep_alive: Duration::from_secs(60), - clean_session: true, - client_id: id, - credentials: None, - max_incoming_packet_size: 10 * 1024, - max_outgoing_packet_size: 10 * 1024, - request_channel_capacity: 10, - max_request_batch: 0, - pending_throttle: Duration::from_micros(0), - inflight: 100, - last_will: None, - conn_timeout: 5, - manual_acks: false, - } - } - - /// Broker address - pub fn broker_address(&self) -> (String, u16) { - (self.broker_addr.clone(), self.port) - } - - pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { - self.last_will = Some(will); - self - } - - pub fn last_will(&self) -> Option { - self.last_will.clone() - } - - pub fn set_transport(&mut self, transport: Transport) -> &mut Self { - self.transport = transport; - self - } - - pub fn transport(&self) -> Transport { - self.transport.clone() - } - - /// Set number of seconds after which client should ping the broker - /// if there is no other data exchange - pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } - - self.keep_alive = duration; - self - } - - /// Keep alive time - pub fn keep_alive(&self) -> Duration { - self.keep_alive - } - - /// Client identifier - pub fn client_id(&self) -> String { - self.client_id.clone() - } - - /// Set packet size limit for outgoing an incoming packets - pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { - self.max_incoming_packet_size = incoming; - self.max_outgoing_packet_size = outgoing; - self - } - - /// Maximum packet size - pub fn max_packet_size(&self) -> usize { - self.max_incoming_packet_size - } - - /// `clean_session = true` removes all the state from queues & instructs the broker - /// to clean all the client state when client disconnects. - /// - /// When set `false`, broker will hold the client state and performs pending - /// operations on the client when reconnection with same `client_id` - /// happens. Local queue state is also held to retransmit packets after reconnection. - pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { - self.clean_session = clean_session; - self - } - - /// Clean session - pub fn clean_session(&self) -> bool { - self.clean_session - } - - /// Username and password - pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { - self.credentials = Some((username.into(), password.into())); - self - } - - /// Security options - pub fn credentials(&self) -> Option<(String, String)> { - self.credentials.clone() - } - - /// Set request channel capacity - pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { - self.request_channel_capacity = capacity; - self - } - - /// Request channel capacity - pub fn request_channel_capacity(&self) -> usize { - self.request_channel_capacity - } - - /// Enables throttling and sets outoing message rate to the specified 'rate' - pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { - self.pending_throttle = duration; - self - } - - /// Outgoing message rate - pub fn pending_throttle(&self) -> Duration { - self.pending_throttle - } - - /// Set number of concurrent in flight messages - pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - if inflight == 0 { - panic!("zero in flight is not allowed") - } - - self.inflight = inflight; - self - } - - /// Number of concurrent in flight messages - pub fn inflight(&self) -> u16 { - self.inflight - } - - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; - self - } - - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout - } - - /// set manual acknowledgements - pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { - self.manual_acks = manual_acks; - self - } - - /// get manual acknowledgements - pub fn manual_acks(&self) -> bool { - self.manual_acks - } -} - -#[cfg(feature = "url")] -#[derive(Debug, PartialEq, thiserror::Error)] -pub enum OptionError { - #[error("Unsupported URL scheme.")] - Scheme, - - #[error("Missing client ID.")] - ClientId, - - #[error("Invalid keep-alive value.")] - KeepAlive, - - #[error("Invalid clean-session value.")] - CleanSession, - - #[error("Invalid max-incoming-packet-size value.")] - MaxIncomingPacketSize, - - #[error("Invalid max-outgoing-packet-size value.")] - MaxOutgoingPacketSize, - - #[error("Invalid request-channel-capacity value.")] - RequestChannelCapacity, - - #[error("Invalid max-request-batch value.")] - MaxRequestBatch, - - #[error("Invalid pending-throttle value.")] - PendingThrottle, - - #[error("Invalid inflight value.")] - Inflight, - - #[error("Invalid conn-timeout value.")] - ConnTimeout, - - #[error("Unknown option: {0}")] - Unknown(String), -} - -#[cfg(feature = "url")] -impl std::convert::TryFrom for MqttOptions { - type Error = OptionError; - - fn try_from(url: url::Url) -> Result { - use std::collections::HashMap; - - let broker_addr = url.host_str().unwrap_or_default().to_owned(); - - let (transport, default_port) = match url.scheme() { - // Encrypted connections are supported, but require explicit TLS configuration. We fall - // back to the unencrypted transport layer, so that `set_transport` can be used to - // configure the encrypted transport layer with the provided TLS configuration. - "mqtts" | "ssl" => (Transport::Tcp, 8883), - "mqtt" | "tcp" => (Transport::Tcp, 1883), - _ => return Err(OptionError::Scheme), - }; - - let port = url.port().unwrap_or(default_port); - - let mut queries = url.query_pairs().collect::>(); - - let keep_alive = Duration::from_secs( - queries - .remove("keep_alive_secs") - .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) - .transpose()? - .unwrap_or(60), - ); - - let client_id = queries - .remove("client_id") - .ok_or(OptionError::ClientId)? - .into_owned(); - - let clean_session = queries - .remove("clean_session") - .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) - .transpose()? - .unwrap_or(true); - - let credentials = { - match url.username() { - "" => None, - username => Some(( - username.to_owned(), - url.password().unwrap_or_default().to_owned(), - )), - } - }; - - let max_incoming_packet_size = queries - .remove("max_incoming_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxIncomingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let max_outgoing_packet_size = queries - .remove("max_outgoing_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxOutgoingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let request_channel_capacity = queries - .remove("request_channel_capacity_num") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::RequestChannelCapacity) - }) - .transpose()? - .unwrap_or(10); - - let max_request_batch = queries - .remove("max_request_batch_num") - .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) - .transpose()? - .unwrap_or(0); - - let pending_throttle = Duration::from_micros( - queries - .remove("pending_throttle_usecs") - .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) - .transpose()? - .unwrap_or(0), - ); - - let inflight = queries - .remove("inflight_num") - .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) - .transpose()? - .unwrap_or(100); - - let conn_timeout = queries - .remove("conn_timeout_secs") - .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) - .transpose()? - .unwrap_or(5); - - if let Some((opt, _)) = queries.into_iter().next() { - return Err(OptionError::Unknown(opt.into_owned())); - } - - Ok(Self { - broker_addr, - port, - transport, - keep_alive, - clean_session, - client_id, - credentials, - max_incoming_packet_size, - max_outgoing_packet_size, - request_channel_capacity, - max_request_batch, - pending_throttle, - inflight, - last_will: None, - conn_timeout, - manual_acks: false - }) - } -} - -// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't -// work. -impl Debug for MqttOptions { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("MqttOptions") - .field("broker_addr", &self.broker_addr) - .field("port", &self.port) - .field("keep_alive", &self.keep_alive) - .field("clean_session", &self.clean_session) - .field("client_id", &self.client_id) - .field("credentials", &self.credentials) - .field("max_packet_size", &self.max_incoming_packet_size) - .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) - .field("pending_throttle", &self.pending_throttle) - .field("inflight", &self.inflight) - .field("last_will", &self.last_will) - .field("conn_timeout", &self.conn_timeout) - .field("manual_acks", &self.manual_acks) - .finish() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[should_panic] - fn client_id_startswith_space() { - let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); - } - - #[test] - #[cfg(feature = "websocket")] - fn no_scheme() { - let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - - _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); - - if let crate::Transport::Wss(TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }) = _mqtt_opts.transport - { - assert_eq!(ca, Vec::from("Test CA")); - assert_eq!(client_auth, None); - assert_eq!(alpn, None); - } else { - panic!("Unexpected transport!"); - } - - assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); - } - - #[test] - #[cfg(feature = "url")] - fn from_url() { - use std::convert::TryInto; - use std::str::FromStr; - - fn opt(s: &str) -> Result { - url::Url::from_str(s).expect("valid url").try_into() - } - fn ok(s: &str) -> MqttOptions { - opt(s).expect("valid options") - } - fn err(s: &str) -> OptionError { - opt(s).expect_err("invalid options") - } - - let v = ok("mqtt://host:42?client_id=foo"); - assert_eq!(v.broker_address(), ("host".to_owned(), 42)); - assert_eq!(v.client_id(), "foo".to_owned()); - - let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); - assert_eq!(v.keep_alive, Duration::from_secs(5)); - - assert_eq!(err("mqtt://host:42"), OptionError::ClientId); - assert_eq!( - err("mqtt://host:42?client_id=foo&foo=bar"), - OptionError::Unknown("foo".to_owned()) - ); - assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); - assert_eq!( - err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), - OptionError::KeepAlive - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&clean_session=foo"), - OptionError::CleanSession - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), - OptionError::MaxIncomingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), - OptionError::MaxOutgoingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), - OptionError::RequestChannelCapacity - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), - OptionError::MaxRequestBatch - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), - OptionError::PendingThrottle - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&inflight_num=foo"), - OptionError::Inflight - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), - OptionError::ConnTimeout - ); - } - - #[test] - #[should_panic] - fn no_client_id() { - let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); - } -} +pub mod v4; +pub mod v5; diff --git a/rumqttc/src/client.rs b/rumqttc/src/v4/client.rs similarity index 99% rename from rumqttc/src/client.rs rename to rumqttc/src/v4/client.rs index 139f54389..8b95c1bc2 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/v4/client.rs @@ -1,6 +1,6 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::v4::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/v4/eventloop.rs similarity index 98% rename from rumqttc/src/eventloop.rs rename to rumqttc/src/v4/eventloop.rs index b4065cf7e..d4ce78c7a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/v4/eventloop.rs @@ -1,6 +1,6 @@ -use crate::{framed::Network, Transport}; -use crate::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::{MqttOptions, Outgoing}; +use crate::v4::{framed::Network, Transport}; +use crate::v4::{tls, Incoming, MqttState, Packet, Request, StateError}; +use crate::v4::{MqttOptions, Outgoing}; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] diff --git a/rumqttc/src/framed.rs b/rumqttc/src/v4/framed.rs similarity index 98% rename from rumqttc/src/framed.rs rename to rumqttc/src/v4/framed.rs index 995234c70..96bda17da 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/v4/framed.rs @@ -3,7 +3,7 @@ use mqttbytes::v4::*; use mqttbytes::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::{Incoming, MqttState, StateError}; +use crate::v4::{Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs new file mode 100644 index 000000000..4797b0f2a --- /dev/null +++ b/rumqttc/src/v4/mod.rs @@ -0,0 +1,706 @@ +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; + +mod client; +mod eventloop; +mod framed; +mod state; +mod tls; + +pub use async_channel::{SendError, Sender, TrySendError}; +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use mqttbytes::v4::*; +pub use mqttbytes::*; +pub use state::{MqttState, StateError}; +pub use tokio_rustls::rustls::ClientConfig; +pub use tls::Error; + +pub type Incoming = Packet; + +/// Current outgoing activity on the eventloop +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// Subscribe packet with packet identifier + Subscribe(u16), + /// Unsubscribe packet with packet identifier + Unsubscribe(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, + /// Await for an ack for more outgoing progress + AwaitAck(u16), +} + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(feature = "websocket")] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/state.rs b/rumqttc/src/v4/state.rs similarity index 99% rename from rumqttc/src/state.rs rename to rumqttc/src/v4/state.rs index e33c7eb6e..c51116ccf 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/v4/state.rs @@ -1,4 +1,4 @@ -use crate::{Event, Incoming, Outgoing, Request}; +use crate::v4::{Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use mqttbytes::v4::*; @@ -487,7 +487,7 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::{Event, Incoming, MqttOptions, Outgoing, Request}; + use crate::v4::{Event, Incoming, MqttOptions, Outgoing, Request}; use mqttbytes::v4::*; use mqttbytes::*; diff --git a/rumqttc/src/tls.rs b/rumqttc/src/v4/tls.rs similarity index 98% rename from rumqttc/src/tls.rs rename to rumqttc/src/v4/tls.rs index 747614015..ea1809e23 100644 --- a/rumqttc/src/tls.rs +++ b/rumqttc/src/v4/tls.rs @@ -7,7 +7,7 @@ use tokio_rustls::rustls::{ use tokio_rustls::webpki; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::{Key, MqttOptions, TlsConfiguration}; +use crate::v4::{Key, MqttOptions, TlsConfiguration}; use std::convert::TryFrom; use std::io; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index 13c9ffc50..194f8ec92 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -9,7 +9,7 @@ use tokio::{task, time}; use async_channel::{bounded, Receiver, Sender}; use bytes::BytesMut; -use rumqttc::{Event, Incoming, Outgoing, Packet}; +use rumqttc::v4::{Event, Incoming, Outgoing, Packet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub struct Broker { diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 454c425da..3a129b6b8 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -5,7 +5,7 @@ use tokio::{task, time}; mod broker; use broker::*; -use rumqttc::*; +use rumqttc::v4::*; async fn start_requests(count: u8, qos: QoS, delay: u64, requests_tx: Sender) { for i in 1..=count { From 40716a3b89567abd3b9dd1eea1325e4c1d0aec80 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 17:49:20 +0530 Subject: [PATCH 02/41] rumqttc: v5: init with v4 code Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 423 +++++++++++++++++++ rumqttc/src/v5/eventloop.rs | 380 ++++++++++++++++++ rumqttc/src/v5/framed.rs | 122 ++++++ rumqttc/src/v5/mod.rs | 707 ++++++++++++++++++++++++++++++++ rumqttc/src/v5/packet.rs | 0 rumqttc/src/v5/state.rs | 781 ++++++++++++++++++++++++++++++++++++ rumqttc/src/v5/tls.rs | 130 ++++++ 7 files changed, 2543 insertions(+) create mode 100644 rumqttc/src/v5/client.rs create mode 100644 rumqttc/src/v5/eventloop.rs create mode 100644 rumqttc/src/v5/framed.rs create mode 100644 rumqttc/src/v5/packet.rs create mode 100644 rumqttc/src/v5/state.rs create mode 100644 rumqttc/src/v5/tls.rs diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs new file mode 100644 index 000000000..8b95c1bc2 --- /dev/null +++ b/rumqttc/src/v5/client.rs @@ -0,0 +1,423 @@ +//! This module offers a high level synchronous and asynchronous abstraction to +//! async eventloop. +use crate::v4::{ConnectionError, Event, EventLoop, MqttOptions, Request}; + +use async_channel::{SendError, Sender, TrySendError}; +use bytes::Bytes; +use mqttbytes::v4::*; +use mqttbytes::*; +use std::mem; +use tokio::runtime; +use tokio::runtime::Runtime; + +/// Client Error +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("Failed to send cancel request to eventloop")] + Cancel(#[from] SendError<()>), + #[error("Failed to send mqtt requests to eventloop")] + Request(#[from] SendError), + #[error("Failed to send mqtt requests to eventloop")] + TryRequest(#[from] TrySendError), + #[error("Serialization error")] + Mqtt4(mqttbytes::Error), +} + +/// `AsyncClient` to communicate with MQTT `Eventloop` +/// This is cloneable and can be used to asynchronously Publish, Subscribe. +#[derive(Clone, Debug)] +pub struct AsyncClient { + request_tx: Sender, + cancel_tx: Sender<()>, +} + +impl AsyncClient { + /// Create a new `AsyncClient` + pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { + let mut eventloop = EventLoop::new(options, cap); + let request_tx = eventloop.handle(); + let cancel_tx = eventloop.cancel_handle(); + + let client = AsyncClient { + request_tx, + cancel_tx, + }; + + (client, eventloop) + } + + /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for + /// creating a test instance. + pub fn from_senders(request_tx: Sender, cancel_tx: Sender<()>) -> AsyncClient { + AsyncClient { + request_tx, + cancel_tx, + } + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let publish = Request::Publish(publish); + self.request_tx.send(publish).await?; + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let publish = Request::Publish(publish); + self.request_tx.try_send(publish)?; + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + let ack = get_ack_req(publish); + + if let Some(ack) = ack { + self.request_tx.send(ack).await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + let ack = get_ack_req(publish); + if let Some(ack) = ack { + self.request_tx.try_send(ack)?; + } + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: Bytes, + ) -> Result<(), ClientError> + where + S: Into, + { + let mut publish = Publish::from_bytes(topic, qos, payload); + publish.retain = retain; + let publish = Request::Publish(publish); + self.request_tx.send(publish).await?; + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + let request = Request::Subscribe(subscribe); + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + let request = Request::Subscribe(subscribe); + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + let request = Request::Subscribe(subscribe); + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + let request = Request::Subscribe(subscribe); + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + let request = Request::Unsubscribe(unsubscribe); + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + let request = Request::Unsubscribe(unsubscribe); + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub async fn disconnect(&self) -> Result<(), ClientError> { + let request = Request::Disconnect; + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + let request = Request::Disconnect; + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Stops the eventloop right away + pub async fn cancel(&self) -> Result<(), ClientError> { + self.cancel_tx.send(()).await?; + Ok(()) + } +} + +fn get_ack_req(publish: &Publish) -> Option { + let ack = match publish.qos { + QoS::AtMostOnce => return None, + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)) + }; + Some(ack) +} + +/// `Client` to communicate with MQTT eventloop `Connection`. +/// +/// Client is cloneable and can be used to synchronously Publish, Subscribe. +/// Asynchronous channel handle can also be extracted if necessary +#[derive(Clone)] +pub struct Client { + client: AsyncClient, +} + +impl Client { + /// Create a new `Client` + pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { + let (client, eventloop) = AsyncClient::new(options, cap); + let client = Client { client }; + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let connection = Connection::new(eventloop, runtime); + (client, connection) + } + + /// Sends a MQTT Publish to the eventloop + pub fn publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + pollster::block_on(self.client.publish(topic, qos, retain, payload))?; + Ok(()) + } + + pub fn try_publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + self.client.try_publish(topic, qos, retain, payload)?; + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + pollster::block_on(self.client.ack(publish))?; + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + self.client.try_ack(publish)?; + Ok(()) + } + + + /// Sends a MQTT Subscribe to the eventloop + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { + pollster::block_on(self.client.subscribe(topic, qos))?; + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result<(), ClientError> { + self.client.try_subscribe(topic, qos)?; + Ok(()) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + pollster::block_on(self.client.subscribe_many(topics)) + } + + pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + self.client.try_subscribe_many(topics) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + pollster::block_on(self.client.unsubscribe(topic))?; + Ok(()) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + self.client.try_unsubscribe(topic)?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn disconnect(&mut self) -> Result<(), ClientError> { + pollster::block_on(self.client.disconnect())?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + self.client.try_disconnect()?; + Ok(()) + } + + /// Stops the eventloop right away + pub fn cancel(&mut self) -> Result<(), ClientError> { + pollster::block_on(self.client.cancel())?; + Ok(()) + } +} + +/// MQTT connection. Maintains all the necessary state +pub struct Connection { + pub eventloop: EventLoop, + runtime: Option, +} + +impl Connection { + fn new(eventloop: EventLoop, runtime: Runtime) -> Connection { + Connection { + eventloop, + runtime: Some(runtime), + } + } + + /// Returns an iterator over this connection. Iterating over this is all that's + /// necessary to make connection progress and maintain a robust connection. + /// Just continuing to loop will reconnect + /// **NOTE** Don't block this while iterating + #[must_use = "Connection should be iterated over a loop to make progress"] + pub fn iter(&mut self) -> Iter { + let runtime = self.runtime.take().unwrap(); + Iter { + connection: self, + runtime, + } + } +} + +/// Iterator which polls the eventloop for connection progress +pub struct Iter<'a> { + connection: &'a mut Connection, + runtime: runtime::Runtime, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + let f = self.connection.eventloop.poll(); + match self.runtime.block_on(f) { + Ok(v) => Some(Ok(v)), + // closing of request channel should stop the iterator + Err(ConnectionError::RequestsDone) => { + trace!("Done with requests"); + None + } + Err(ConnectionError::Cancel) => { + trace!("Cancellation request received"); + None + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a> Drop for Iter<'a> { + fn drop(&mut self) { + // TODO: Don't create new runtime in drop + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + self.connection.runtime = Some(mem::replace(&mut self.runtime, runtime)); + } +} diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs new file mode 100644 index 000000000..d4ce78c7a --- /dev/null +++ b/rumqttc/src/v5/eventloop.rs @@ -0,0 +1,380 @@ +use crate::v4::{framed::Network, Transport}; +use crate::v4::{tls, Incoming, MqttState, Packet, Request, StateError}; +use crate::v4::{MqttOptions, Outgoing}; + +use async_channel::{bounded, Receiver, Sender}; +#[cfg(feature = "websocket")] +use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use mqttbytes::v4::*; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::select; +use tokio::time::{self, error::Elapsed, Instant, Sleep}; +#[cfg(feature = "websocket")] +use ws_stream_tungstenite::WsStream; + +use std::io; +#[cfg(unix)] +use std::path::Path; +use std::pin::Pin; +use std::time::Duration; +use std::vec::IntoIter; + +/// Critical errors during eventloop polling +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("Mqtt state: {0}")] + MqttState(#[from] StateError), + #[error("Timeout")] + Timeout(#[from] Elapsed), + #[error("Packet parsing error: {0}")] + Mqtt4Bytes(mqttbytes::Error), + #[error("Network: {0}")] + Network(#[from] tls::Error), + #[error("I/O: {0}")] + Io(#[from] io::Error), + #[error("Stream done")] + StreamDone, + #[error("Requests done")] + RequestsDone, + #[error("Cancel request by the user")] + Cancel, +} + +/// Eventloop with all the state of a connection +pub struct EventLoop { + /// Options of the current mqtt connection + pub options: MqttOptions, + /// Current state of the connection + pub state: MqttState, + /// Request stream + pub requests_rx: Receiver, + /// Requests handle to send requests + pub requests_tx: Sender, + /// Pending packets from last session + pub pending: IntoIter, + /// Network connection to the broker + pub(crate) network: Option, + /// Keep alive time + pub(crate) keepalive_timeout: Option>>, + /// Handle to read cancellation requests + pub(crate) cancel_rx: Receiver<()>, + /// Handle to send cancellation requests (and drops) + pub(crate) cancel_tx: Sender<()>, +} + +/// Events which can be yielded by the event loop +#[derive(Debug, PartialEq, Clone)] +pub enum Event { + Incoming(Incoming), + Outgoing(Outgoing), +} + +impl EventLoop { + /// New MQTT `EventLoop` + /// + /// When connection encounters critical errors (like auth failure), user has a choice to + /// access and update `options`, `state` and `requests`. + pub fn new(options: MqttOptions, cap: usize) -> EventLoop { + let (cancel_tx, cancel_rx) = bounded(5); + let (requests_tx, requests_rx) = bounded(cap); + let pending = Vec::new(); + let pending = pending.into_iter(); + let max_inflight = options.inflight; + let manual_acks = options.manual_acks; + + EventLoop { + options, + state: MqttState::new(max_inflight, manual_acks), + requests_tx, + requests_rx, + pending, + network: None, + keepalive_timeout: None, + cancel_rx, + cancel_tx, + } + } + + /// Returns a handle to communicate with this eventloop + pub fn handle(&self) -> Sender { + self.requests_tx.clone() + } + + /// Handle for cancelling the eventloop. + /// + /// Can be useful in cases when connection should be halted immediately + /// between half-open connection detections or (re)connection timeouts + pub(crate) fn cancel_handle(&mut self) -> Sender<()> { + self.cancel_tx.clone() + } + + fn clean(&mut self) { + self.network = None; + self.keepalive_timeout = None; + let pending = self.state.clean(); + self.pending = pending.into_iter(); + } + + /// Yields Next notification or outgoing request and periodically pings + /// the broker. Continuing to poll will reconnect to the broker if there is + /// a disconnection. + /// **NOTE** Don't block this while iterating + pub async fn poll(&mut self) -> Result { + if self.network.is_none() { + let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?; + self.network = Some(network); + + if self.keepalive_timeout.is_none() { + self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); + } + + return Ok(Event::Incoming(connack)); + } + + match self.select().await { + Ok(v) => Ok(v), + Err(e) => { + self.clean(); + Err(e) + } + } + } + + /// Select on network and requests and generate keepalive pings when necessary + async fn select(&mut self) -> Result { + let network = self.network.as_mut().unwrap(); + // let await_acks = self.state.await_acks; + let inflight_full = self.state.inflight >= self.options.inflight; + let throttle = self.options.pending_throttle; + let pending = self.pending.len() > 0; + let collision = self.state.collision.is_some(); + + // Read buffered events from previous polls before calling a new poll + if let Some(event) = self.state.events.pop_front() { + return Ok(event); + } + + // this loop is necessary since self.incoming.pop_front() might return None. In that case, + // instead of returning a None event, we try again. + select! { + // Pull a bunch of packets from network, reply in bunch and yield the first item + o = network.readb(&mut self.state) => { + o?; + // flush all the acks and return first incoming packet + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + }, + // Pull next request from user requests channel. + // If conditions in the below branch are for flow control. We read next user + // user request only when inflight messages are < configured inflight and there + // are no collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this + // to work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue + // looks like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { + Ok(request) => { + self.state.handle_outgoing_packet(request)?; + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + } + Err(_) => Err(ConnectionError::RequestsDone), + }, + // Handle the next pending packet from previous session. Disable + // this branch when done with all the pending packets + Some(request) = next_pending(throttle, &mut self.pending), if pending => { + self.state.handle_outgoing_packet(request)?; + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + + self.state.handle_outgoing_packet(Request::PingReq)?; + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + } + // cancellation requests to stop the polling + _ = self.cancel_rx.recv() => { + Err(ConnectionError::Cancel) + } + } + } +} + +async fn connect_or_cancel( + options: &MqttOptions, + cancel_rx: &Receiver<()>, +) -> Result<(Network, Incoming), ConnectionError> { + // select here prevents cancel request from being blocked until connection request is + // resolved. Returns with an error if connections fail continuously + select! { + o = connect(options) => o, + _ = cancel_rx.recv() => { + Err(ConnectionError::Cancel) + } + } +} + +/// This stream internally processes requests from the request stream provided to the eventloop +/// while also consuming byte stream from the network and yielding mqtt packets as the output of +/// the stream. +/// This function (for convenience) includes internal delays for users to perform internal sleeps +/// between re-connections so that cancel semantics can be used during this sleep +async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> { + // connect to the broker + let mut network = match network_connect(options).await { + Ok(network) => network, + Err(e) => { + return Err(e); + } + }; + + // make MQTT connection request (which internally awaits for ack) + let packet = match mqtt_connect(options, &mut network).await { + Ok(p) => p, + Err(e) => return Err(e), + }; + + // Last session might contain packets which aren't acked. MQTT says these packets should be + // republished in the next session + // move pending messages from state to eventloop + // let pending = self.state.clean(); + // self.pending = pending.into_iter(); + Ok((network, packet)) +} + +async fn network_connect(options: &MqttOptions) -> Result { + let network = match options.transport() { + Transport::Tcp => { + let addr = options.broker_addr.as_str(); + let port = options.port; + let socket = TcpStream::connect((addr, port)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + Transport::Tls(tls_config) => { + let socket = tls::tls_connect(&options, &tls_config).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(unix)] + Transport::Unix => { + let file = options.broker_addr.as_str(); + let socket = UnixStream::connect(Path::new(file)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Ws => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let (socket, _) = connect_async(request) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Wss(tls_config) => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let connector = tls::tls_connector(&tls_config).await?; + + let (socket, _) = connect_async_with_tls_connector(request, Some(connector)) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + }; + + Ok(network) +} + +async fn mqtt_connect( + options: &MqttOptions, + network: &mut Network, +) -> Result { + let keep_alive = options.keep_alive().as_secs() as u16; + let clean_session = options.clean_session(); + let last_will = options.last_will(); + + let mut connect = Connect::new(options.client_id()); + connect.keep_alive = keep_alive; + connect.clean_session = clean_session; + connect.last_will = last_will; + + if let Some((username, password)) = options.credentials() { + let login = Login::new(username, password); + connect.login = Some(login); + } + + // mqtt connection with timeout + time::timeout(Duration::from_secs(options.connection_timeout()), async { + network.connect(connect).await?; + Ok::<_, ConnectionError>(()) + }) + .await??; + + // wait for 'timeout' time to validate connack + let packet = time::timeout(Duration::from_secs(options.connection_timeout()), async { + let packet = match network.read().await? { + Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { + Packet::ConnAck(connack) + } + Incoming::ConnAck(connack) => { + let error = format!("Broker rejected. Reason = {:?}", connack.code); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + packet => { + let error = format!("Expecting connack. Received = {:?}", packet); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + }; + + io::Result::Ok(packet) + }) + .await??; + + Ok(packet) +} + +/// Returns the next pending packet asynchronously to be used in select! +/// This is a synchronous function but made async to make it fit in select! +pub(crate) async fn next_pending( + delay: Duration, + pending: &mut IntoIter, +) -> Option { + // return next packet with a delay + time::sleep(delay).await; + pending.next() +} diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs new file mode 100644 index 000000000..96bda17da --- /dev/null +++ b/rumqttc/src/v5/framed.rs @@ -0,0 +1,122 @@ +use bytes::BytesMut; +use mqttbytes::v4::*; +use mqttbytes::*; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::v4::{Incoming, MqttState, StateError}; +use std::io; + +/// Network transforms packets <-> frames efficiently. It takes +/// advantage of pre-allocation, buffering and vectorization when +/// appropriate to achieve performance +pub struct Network { + /// Socket for IO + socket: Box, + /// Buffered reads + read: BytesMut, + /// Maximum packet size + max_incoming_size: usize, + /// Maximum readv count + max_readb_count: usize, +} + +impl Network { + pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network { + let socket = Box::new(socket) as Box; + Network { + socket, + read: BytesMut::with_capacity(10 * 1024), + max_incoming_size, + max_readb_count: 10, + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + async fn read_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + loop { + let read = self.socket.read_buf(&mut self.read).await?; + if 0 == read { + return if self.read.is_empty() { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "connection closed by peer", + )) + } else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "connection reset by peer", + )) + }; + } + + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let required = match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => return Ok(packet), + Err(mqttbytes::Error::InsufficientBytes(required)) => required, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + // read more packets until a frame can be created. This function + // blocks until a frame can be created. Use this in a select! branch + self.read_bytes(required).await?; + } + } + + /// Read packets in bulk. This allow replies to be in bulk. This method is used + /// after the connection is established to read a bunch of incoming packets + pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { + let mut count = 0; + loop { + match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => { + state.handle_incoming_packet(packet)?; + + count += 1; + if count >= self.max_readb_count { + return Ok(()); + } + } + // If some packets are already framed, return those + Err(Error::InsufficientBytes(_)) if count > 0 => return Ok(()), + // Wait for more bytes until a frame can be created + Err(Error::InsufficientBytes(required)) => { + self.read_bytes(required).await?; + } + Err(e) => return Err(StateError::Deserialization(e)), + }; + } + } + + pub async fn connect(&mut self, connect: Connect) -> io::Result { + let mut write = BytesMut::new(); + let len = match connect.write(&mut write) { + Ok(size) => size, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + self.socket.write_all(&write[..]).await?; + Ok(len) + } + + pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { + if write.is_empty() { + return Ok(()); + } + + self.socket.write_all(&write[..]).await?; + write.clear(); + Ok(()) + } +} + +pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} +impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index e69de29bb..01800c685 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -0,0 +1,707 @@ +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; + +mod client; +mod eventloop; +mod framed; +mod packet; +mod state; +mod tls; + +pub use async_channel::{SendError, Sender, TrySendError}; +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use mqttbytes::v4::*; +pub use mqttbytes::*; +pub use state::{MqttState, StateError}; +pub use tokio_rustls::rustls::ClientConfig; +pub use tls::Error; + +pub type Incoming = Packet; + +/// Current outgoing activity on the eventloop +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// Subscribe packet with packet identifier + Subscribe(u16), + /// Unsubscribe packet with packet identifier + Unsubscribe(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, + /// Await for an ack for more outgoing progress + AwaitAck(u16), +} + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(feature = "websocket")] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/v5/packet.rs b/rumqttc/src/v5/packet.rs new file mode 100644 index 000000000..e69de29bb diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs new file mode 100644 index 000000000..c51116ccf --- /dev/null +++ b/rumqttc/src/v5/state.rs @@ -0,0 +1,781 @@ +use crate::v4::{Event, Incoming, Outgoing, Request}; + +use bytes::BytesMut; +use mqttbytes::v4::*; +use mqttbytes::*; +use std::collections::VecDeque; +use std::{io, mem, time::Instant}; + +/// Errors during state handling +#[derive(Debug, thiserror::Error)] +pub enum StateError { + /// Io Error while state is passed to network + #[error("Io error {0:?}")] + Io(#[from] io::Error), + /// Broker's error reply to client's connect packet + #[error("Connect return code `{0:?}`")] + Connect(ConnectReturnCode), + /// Invalid state for a given operation + #[error("Invalid state for a given operation")] + InvalidState, + /// Received a packet (ack) which isn't asked for + #[error("Received unsolicited ack pkid {0}")] + Unsolicited(u16), + /// Last pingreq isn't acked + #[error("Last pingreq isn't acked")] + AwaitPingResp, + /// Received a wrong packet while waiting for another packet + #[error("Received a wrong packet while waiting for another packet")] + WrongPacket, + #[error("Timeout while waiting to resolve collision")] + CollisionTimeout, + #[error("Mqtt serialization/deserialization error")] + Deserialization(mqttbytes::Error), +} + +impl From for StateError { + fn from(e: mqttbytes::Error) -> StateError { + StateError::Deserialization(e) + } +} + +/// State of the mqtt connection. +// Design: Methods will just modify the state of the object without doing any network operations +// Design: All inflight queues are maintained in a pre initialized vec with index as packet id. +// This is done for 2 reasons +// Bad acks or out of order acks aren't O(n) causing cpu spikes +// Any missing acks from the broker are detected during the next recycled use of packet ids +#[derive(Debug, Clone)] +pub struct MqttState { + /// Status of last ping + pub await_pingresp: bool, + /// Collision ping count. Collisions stop user requests + /// which inturn trigger pings. Multiple pings without + /// resolving collisions will result in error + pub collision_ping_count: usize, + /// Last incoming packet time + last_incoming: Instant, + /// Last outgoing packet time + last_outgoing: Instant, + /// Packet id of the last outgoing packet + pub(crate) last_pkid: u16, + /// Number of outgoing inflight publishes + pub(crate) inflight: u16, + /// Maximum number of allowed inflight + pub(crate) max_inflight: u16, + /// Outgoing QoS 1, 2 publishes which aren't acked yet + pub(crate) outgoing_pub: Vec>, + /// Packet ids of released QoS 2 publishes + pub(crate) outgoing_rel: Vec>, + /// Packet ids on incoming QoS 2 publishes + pub(crate) incoming_pub: Vec>, + /// Last collision due to broker not acking in order + pub collision: Option, + /// Buffered incoming packets + pub events: VecDeque, + /// Write buffer + pub write: BytesMut, + /// Indicates if acknowledgements should be send immediately + pub manual_acks: bool, +} + +impl MqttState { + /// Creates new mqtt state. Same state should be used during a + /// connection for persistent sessions while new state should + /// instantiated for clean sessions + pub fn new(max_inflight: u16, manual_acks: bool) -> Self { + MqttState { + await_pingresp: false, + collision_ping_count: 0, + last_incoming: Instant::now(), + last_outgoing: Instant::now(), + last_pkid: 0, + inflight: 0, + max_inflight, + // index 0 is wasted as 0 is not a valid packet id + outgoing_pub: vec![None; max_inflight as usize + 1], + outgoing_rel: vec![None; max_inflight as usize + 1], + incoming_pub: vec![None; std::u16::MAX as usize + 1], + collision: None, + // TODO: Optimize these sizes later + events: VecDeque::with_capacity(100), + write: BytesMut::with_capacity(10 * 1024), + manual_acks + } + } + + /// Returns inflight outgoing packets and clears internal queues + pub fn clean(&mut self) -> Vec { + let mut pending = Vec::with_capacity(100); + // remove and collect pending publishes + for publish in self.outgoing_pub.iter_mut() { + if let Some(publish) = publish.take() { + let request = Request::Publish(publish); + pending.push(request); + } + } + + // remove and collect pending releases + for rel in self.outgoing_rel.iter_mut() { + if let Some(pkid) = rel.take() { + let request = Request::PubRel(PubRel::new(pkid)); + pending.push(request); + } + } + + // remove packed ids of incoming qos2 publishes + for id in self.incoming_pub.iter_mut() { + id.take(); + } + + self.await_pingresp = false; + self.collision_ping_count = 0; + self.inflight = 0; + pending + } + + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should + /// be put on to the network by the eventloop + pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { + match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect => self.outgoing_disconnect()?, + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + _ => unimplemented!(), + }; + + self.last_outgoing = Instant::now(); + Ok(()) + } + + /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the + /// user to consume and `Packet` which for the eventloop to put on the network + /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will + /// be forwarded to user and Pubck packet will be written to network + pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { + let out = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp(), + Incoming::Publish(publish) => self.handle_incoming_publish(publish), + Incoming::SubAck(_suback) => self.handle_incoming_suback(), + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), + Incoming::PubAck(puback) => self.handle_incoming_puback(puback), + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + _ => { + error!("Invalid incoming packet = {:?}", packet); + return Err(StateError::WrongPacket); + } + }; + + out?; + self.events.push_back(Event::Incoming(packet)); + self.last_incoming = Instant::now(); + Ok(()) + } + + fn handle_incoming_suback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + /// Results in a publish notification in all the QoS cases. Replys with an ack + /// in case of QoS1 and Replys rec in case of QoS while also storing the message + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + let qos = publish.qos; + + match qos { + QoS::AtMostOnce => Ok(()), + QoS::AtLeastOnce => { + if !self.manual_acks { + let puback = PubAck::new(publish.pkid); + self.outgoing_puback(puback)? + } + Ok(()) + } + QoS::ExactlyOnce => { + let pkid = publish.pkid; + self.incoming_pub[pkid as usize] = Some(pkid); + if !self.manual_acks { + let pubrec = PubRec::new(pkid); + self.outgoing_pubrec(pubrec)?; + } + Ok(()) + } + } + } + + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + let v = match mem::replace(&mut self.outgoing_pub[puback.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited puback packet: {:?}", puback.pkid); + Err(StateError::Unsolicited(puback.pkid)) + } + }; + + if let Some(publish) = self.check_collision(puback.pkid) { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; + + publish.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + } + + v + } + + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + match mem::replace(&mut self.outgoing_pub[pubrec.pkid as usize], None) { + Some(_) => { + // NOTE: Inflight - 1 for qos2 in comp + self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); + PubRel::new(pubrec.pkid).write(&mut self.write)?; + + let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); + self.events.push_back(event); + Ok(()) + } + None => { + error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); + Err(StateError::Unsolicited(pubrec.pkid)) + } + } + } + + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + match mem::replace(&mut self.incoming_pub[pubrel.pkid as usize], None) { + Some(_) => { + PubComp::new(pubrel.pkid).write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); + self.events.push_back(event); + Ok(()) + } + None => { + error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); + Err(StateError::Unsolicited(pubrel.pkid)) + } + } + } + + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + if let Some(publish) = self.check_collision(pubcomp.pkid) { + publish.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + } + + match mem::replace(&mut self.outgoing_rel[pubcomp.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + Err(StateError::Unsolicited(pubcomp.pkid)) + } + } + } + + fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + self.await_pingresp = false; + Ok(()) + } + + /// Adds next packet identifier to QoS 1 and 2 publish packets and returns + /// it buy wrapping publish in packet + fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + if publish.qos != QoS::AtMostOnce { + if publish.pkid == 0 { + publish.pkid = self.next_pkid(); + } + + let pkid = publish.pkid; + if self + .outgoing_pub + .get(publish.pkid as usize) + .unwrap() + .is_some() + { + info!("Collision on packet id = {:?}", publish.pkid); + self.collision = Some(publish); + let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); + self.events.push_back(event); + return Ok(()); + } + + // if there is an existing publish at this pkid, this implies that broker hasn't acked this + // packet yet. This error is possible only when broker isn't acking sequentially + self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.inflight += 1; + }; + + debug!( + "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", + publish.topic, + publish.pkid, + publish.payload.len() + ); + + publish.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + let pubrel = self.save_pubrel(pubrel)?; + + debug!("Pubrel. Pkid = {}", pubrel.pkid); + PubRel::new(pubrel.pkid).write(&mut self.write)?; + + let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + puback.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + pubrec.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); + self.events.push_back(event); + Ok(()) + } + + /// check when the last control packet/pingreq packet is received and return + /// the status which tells if keep alive time has exceeded + /// NOTE: status will be checked for zero keepalive times also + fn outgoing_ping(&mut self) -> Result<(), StateError> { + let elapsed_in = self.last_incoming.elapsed(); + let elapsed_out = self.last_outgoing.elapsed(); + + if self.collision.is_some() { + self.collision_ping_count += 1; + if self.collision_ping_count >= 2 { + return Err(StateError::CollisionTimeout); + } + } + + // raise error if last ping didn't receive ack + if self.await_pingresp { + return Err(StateError::AwaitPingResp); + } + + self.await_pingresp = true; + + debug!( + "Pingreq, + last incoming packet before {} millisecs, + last outgoing request before {} millisecs", + elapsed_in.as_millis(), + elapsed_out.as_millis() + ); + + PingReq.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PingReq); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + let pkid = self.next_pkid(); + subscription.pkid = pkid; + + debug!( + "Subscribe. Topics = {:?}, Pkid = {:?}", + subscription.filters, subscription.pkid + ); + + subscription.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + let pkid = self.next_pkid(); + unsub.pkid = pkid; + + debug!( + "Unsubscribe. Topics = {:?}, Pkid = {:?}", + unsub.topics, unsub.pkid + ); + + unsub.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + debug!("Disconnect"); + + Disconnect.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Disconnect); + self.events.push_back(event); + Ok(()) + } + + fn check_collision(&mut self, pkid: u16) -> Option { + if let Some(publish) = &self.collision { + if publish.pkid == pkid { + return self.collision.take(); + } + } + + None + } + + fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { + let pubrel = match pubrel.pkid { + // consider PacketIdentifier(0) as uninitialized packets + 0 => { + pubrel.pkid = self.next_pkid(); + pubrel + } + _ => pubrel, + }; + + self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid); + Ok(pubrel) + } + + /// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation + /// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. + /// + fn next_pkid(&mut self) -> u16 { + let next_pkid = self.last_pkid + 1; + + // When next packet id is at the edge of inflight queue, + // set await flag. This instructs eventloop to stop + // processing requests until all the inflight publishes + // are acked + if next_pkid == self.max_inflight { + self.last_pkid = 0; + return next_pkid; + } + + self.last_pkid = next_pkid; + next_pkid + } +} + +#[cfg(test)] +mod test { + use super::{MqttState, StateError}; + use crate::v4::{Event, Incoming, MqttOptions, Outgoing, Request}; + use mqttbytes::v4::*; + use mqttbytes::*; + + fn build_outgoing_publish(qos: QoS) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.qos = qos; + publish + } + + fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.pkid = pkid; + publish.qos = qos; + publish + } + + fn build_mqttstate() -> MqttState { + MqttState::new(100, false) + } + + #[test] + fn next_pkid_increments_as_expected() { + let mut mqtt = build_mqttstate(); + + for i in 1..=100 { + let pkid = mqtt.next_pkid(); + + // loops between 0-99. % 100 == 0 implies border + let expected = i % 100; + if expected == 0 { + break; + } + + assert_eq!(expected, pkid); + } + } + + #[test] + fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + let mut mqtt = build_mqttstate(); + + // QoS0 Publish + let publish = build_outgoing_publish(QoS::AtMostOnce); + + // QoS 0 publish shouldn't be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.last_pkid, 0); + assert_eq!(mqtt.inflight, 0); + + // QoS1 Publish + let publish = build_outgoing_publish(QoS::AtLeastOnce); + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + assert_eq!(mqtt.last_pkid, 1); + assert_eq!(mqtt.inflight, 1); + + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.last_pkid, 2); + assert_eq!(mqtt.inflight, 2); + + // QoS1 Publish + let publish = build_outgoing_publish(QoS::ExactlyOnce); + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + assert_eq!(mqtt.last_pkid, 3); + assert_eq!(mqtt.inflight, 3); + + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.last_pkid, 4); + assert_eq!(mqtt.inflight, 4); + } + + #[test] + fn incoming_publish_should_be_added_to_queue_correctly() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + + // only qos2 publish should be add to queue + assert_eq!(pkid, 3); + } + + #[test] + fn incoming_publish_should_be_acked() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { + assert_eq!(pkid, 2); + } else { + panic!("missing puback") + } + + if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] { + assert_eq!(pkid, 3); + } else { + panic!("missing PubRec") + } + } + + #[test] + fn incoming_publish_should_not_be_acked_with_manual_acks() { + let mut mqtt = build_mqttstate(); + mqtt.manual_acks = true; + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + assert_eq!(pkid, 3); + + assert!(mqtt.events.is_empty()); + } + + #[test] + fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + _ => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_puback_should_remove_correct_publish_from_queue() { + let mut mqtt = build_mqttstate(); + + let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + mqtt.outgoing_publish(publish1).unwrap(); + mqtt.outgoing_publish(publish2).unwrap(); + assert_eq!(mqtt.inflight, 2); + + mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 0); + + assert!(mqtt.outgoing_pub[1].is_none()); + assert!(mqtt.outgoing_pub[2].is_none()); + } + + #[test] + fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let mut mqtt = build_mqttstate(); + + let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + let _publish_out = mqtt.outgoing_publish(publish1); + let _publish_out = mqtt.outgoing_publish(publish2); + + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 2); + + // check if the remaining element's pkid is 1 + let backup = mqtt.outgoing_pub[1].clone(); + assert_eq!(backup.unwrap().pkid, 1); + + // check if the qos2 element's release pkid is 2 + assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + } + + #[test] + fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + + let publish = build_outgoing_publish(QoS::ExactlyOnce); + mqtt.outgoing_publish(publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + let mut mqtt = build_mqttstate(); + let publish = build_outgoing_publish(QoS::ExactlyOnce); + + mqtt.outgoing_publish(publish).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 0); + } + + #[test] + fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + let mut mqtt = build_mqttstate(); + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + mqtt.outgoing_ping().unwrap(); + + // network activity other than pingresp + let publish = build_outgoing_publish(QoS::AtLeastOnce); + mqtt.handle_outgoing_packet(Request::Publish(publish)) + .unwrap(); + mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + .unwrap(); + + // should throw error because we didn't get pingresp for previous ping + match mqtt.outgoing_ping() { + Ok(_) => panic!("Should throw pingresp await error"), + Err(StateError::AwaitPingResp) => (), + Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + } + } + + #[test] + fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + let mut mqtt = build_mqttstate(); + + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // should ping + mqtt.outgoing_ping().unwrap(); + mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // should ping + mqtt.outgoing_ping().unwrap(); + } +} diff --git a/rumqttc/src/v5/tls.rs b/rumqttc/src/v5/tls.rs new file mode 100644 index 000000000..ea1809e23 --- /dev/null +++ b/rumqttc/src/v5/tls.rs @@ -0,0 +1,130 @@ +use tokio::net::TcpStream; +use tokio_rustls::rustls; +use tokio_rustls::rustls::client::InvalidDnsNameError; +use tokio_rustls::rustls::{ + Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName, +}; +use tokio_rustls::webpki; +use tokio_rustls::{client::TlsStream, TlsConnector}; + +use crate::v4::{Key, MqttOptions, TlsConfiguration}; + +use std::convert::TryFrom; +use std::io; +use std::io::{BufReader, Cursor}; +use std::net::AddrParseError; +use std::sync::Arc; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Addr")] + Addr(#[from] AddrParseError), + #[error("I/O")] + Io(#[from] io::Error), + #[error("Web Pki")] + WebPki(#[from] webpki::Error), + #[error("DNS name")] + DNSName(#[from] InvalidDnsNameError), + #[error("TLS error")] + TLS(#[from] rustls::Error), + #[error("No valid cert in chain")] + NoValidCertInChain, +} + +// The cert handling functions return unit right now, this is a shortcut +impl From<()> for Error { + fn from(_: ()) -> Self { + Error::NoValidCertInChain + } +} + +pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result { + let config = match tls_config { + TlsConfiguration::Simple { + ca, + alpn, + client_auth, + } => { + // Add ca to root store if the connection is TLS + let mut root_cert_store = RootCertStore::empty(); + let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))?; + + let trust_anchors = certs.iter().map_while(|cert| { + if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { + Some(OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + )) + } else { + None + } + }); + + root_cert_store.add_server_trust_anchors(trust_anchors); + + if root_cert_store.is_empty() { + return Err(Error::NoValidCertInChain); + } + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + // Add der encoded client cert and key + let mut config = if let Some(client) = client_auth.as_ref() { + let certs = + rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client.0.clone())))?; + // load appropriate Key as per the user request. The underlying signature algorithm + // of key generation determines the Signature Algorithm during the TLS Handskahe. + let read_keys = match &client.1 { + Key::RSA(k) => rustls_pemfile::rsa_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + Key::ECC(k) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + }; + let keys = match read_keys { + Ok(v) => v, + Err(_e) => return Err(Error::NoValidCertInChain), + }; + + // Get the first key. Error if it's not valid + let key = match keys.first() { + Some(k) => k.clone(), + None => return Err(Error::NoValidCertInChain), + }; + + let certs = certs.into_iter().map(|cert| Certificate(cert)).collect(); + + config.with_single_cert(certs, PrivateKey(key))? + } else { + config.with_no_client_auth() + }; + + // Set ALPN + if let Some(alpn) = alpn.as_ref() { + config.alpn_protocols.extend_from_slice(alpn); + } + + Arc::new(config) + } + TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(), + }; + + Ok(TlsConnector::from(config)) +} + +pub async fn tls_connect( + options: &MqttOptions, + tls_config: &TlsConfiguration, +) -> Result, Error> { + let addr = options.broker_addr.as_str(); + let port = options.port; + let connector = tls_connector(tls_config).await?; + let domain = ServerName::try_from(addr)?; + let tcp = TcpStream::connect((addr, port)).await?; + let tls = connector.connect(domain, tcp).await?; + Ok(tls) +} From e94000f62a699e620f1233d899f9ab20dfc0a8ab Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:25:44 +0530 Subject: [PATCH 03/41] rumqttc: v5: packet: copy code from mqttbytes/v5 Signed-off-by: Abhik Jain --- rumqttc/Cargo.toml | 1 + rumqttc/src/v5/packet.rs | 0 rumqttc/src/v5/packet/connack.rs | 553 ++++++++++++++++ rumqttc/src/v5/packet/connect.rs | 928 +++++++++++++++++++++++++++ rumqttc/src/v5/packet/disconnect.rs | 434 +++++++++++++ rumqttc/src/v5/packet/mod.rs | 489 ++++++++++++++ rumqttc/src/v5/packet/ping.rs | 20 + rumqttc/src/v5/packet/puback.rs | 324 ++++++++++ rumqttc/src/v5/packet/pubcomp.rs | 237 +++++++ rumqttc/src/v5/packet/publish.rs | 394 ++++++++++++ rumqttc/src/v5/packet/pubrec.rs | 252 ++++++++ rumqttc/src/v5/packet/pubrel.rs | 236 +++++++ rumqttc/src/v5/packet/suback.rs | 263 ++++++++ rumqttc/src/v5/packet/subscribe.rs | 425 ++++++++++++ rumqttc/src/v5/packet/unsuback.rs | 249 +++++++ rumqttc/src/v5/packet/unsubscribe.rs | 238 +++++++ 16 files changed, 5043 insertions(+) delete mode 100644 rumqttc/src/v5/packet.rs create mode 100644 rumqttc/src/v5/packet/connack.rs create mode 100644 rumqttc/src/v5/packet/connect.rs create mode 100644 rumqttc/src/v5/packet/disconnect.rs create mode 100644 rumqttc/src/v5/packet/mod.rs create mode 100644 rumqttc/src/v5/packet/ping.rs create mode 100644 rumqttc/src/v5/packet/puback.rs create mode 100644 rumqttc/src/v5/packet/pubcomp.rs create mode 100644 rumqttc/src/v5/packet/publish.rs create mode 100644 rumqttc/src/v5/packet/pubrec.rs create mode 100644 rumqttc/src/v5/packet/pubrel.rs create mode 100644 rumqttc/src/v5/packet/suback.rs create mode 100644 rumqttc/src/v5/packet/subscribe.rs create mode 100644 rumqttc/src/v5/packet/unsuback.rs create mode 100644 rumqttc/src/v5/packet/unsubscribe.rs diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 7e8a89b89..81a7b05b7 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -43,3 +43,4 @@ tokio = { version = "1.0", features = ["full", "macros"] } matches = "0.1.8" rustls = "0.20.2" rustls-native-certs = "0.6.1" +pretty_assertions = "1.1.0" diff --git a/rumqttc/src/v5/packet.rs b/rumqttc/src/v5/packet.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/rumqttc/src/v5/packet/connack.rs b/rumqttc/src/v5/packet/connack.rs new file mode 100644 index 000000000..f0e0ebaee --- /dev/null +++ b/rumqttc/src/v5/packet/connack.rs @@ -0,0 +1,553 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum ConnectReturnCode { + Success = 0, + UnspecifiedError = 128, + MalformedPacket = 129, + ProtocolError = 130, + ImplementationSpecificError = 131, + UnsupportedProtocolVersion = 132, + ClientIdentifierNotValid = 133, + BadUserNamePassword = 134, + NotAuthorized = 135, + ServerUnavailable = 136, + ServerBusy = 137, + Banned = 138, + BadAuthenticationMethod = 140, + TopicNameInvalid = 144, + PacketTooLarge = 149, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + RetainNotSupported = 154, + QoSNotSupported = 155, + UseAnotherServer = 156, + ServerMoved = 157, + ConnectionRateExceeded = 159, +} + +/// Acknowledgement to connect packet +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, +} + +impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(self.session_present as u8); + buffer.put_u8(self.code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, +} + +impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let id = read_mqtt_string(&mut bytes)?; + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_string(&mut bytes)?; + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +/// Connection return code type +fn connect_return(num: u8) -> Result { + let code = match num { + 0 => ConnectReturnCode::Success, + 128 => ConnectReturnCode::UnspecifiedError, + 129 => ConnectReturnCode::MalformedPacket, + 130 => ConnectReturnCode::ProtocolError, + 131 => ConnectReturnCode::ImplementationSpecificError, + 132 => ConnectReturnCode::UnsupportedProtocolVersion, + 133 => ConnectReturnCode::ClientIdentifierNotValid, + 134 => ConnectReturnCode::BadUserNamePassword, + 135 => ConnectReturnCode::NotAuthorized, + 136 => ConnectReturnCode::ServerUnavailable, + 137 => ConnectReturnCode::ServerBusy, + 138 => ConnectReturnCode::Banned, + 140 => ConnectReturnCode::BadAuthenticationMethod, + 144 => ConnectReturnCode::TopicNameInvalid, + 149 => ConnectReturnCode::PacketTooLarge, + 151 => ConnectReturnCode::QuotaExceeded, + 153 => ConnectReturnCode::PayloadFormatInvalid, + 154 => ConnectReturnCode::RetainNotSupported, + 155 => ConnectReturnCode::QoSNotSupported, + 156 => ConnectReturnCode::UseAnotherServer, + 157 => ConnectReturnCode::ServerMoved, + 159 => ConnectReturnCode::ConnectionRateExceeded, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample() -> ConnAck { + let properties = ConnAckProperties { + session_expiry_interval: Some(1234), + receive_max: Some(432), + max_qos: Some(2), + retain_available: Some(1), + max_packet_size: Some(100), + assigned_client_identifier: Some("test".to_owned()), + topic_alias_max: Some(456), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + wildcard_subscription_available: Some(1), + subscription_identifiers_available: Some(1), + shared_subscription_available: Some(0), + server_keep_alive: Some(1234), + response_information: Some("test".to_owned()), + server_reference: Some("test".to_owned()), + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + ConnAck { + session_present: false, + code: ConnectReturnCode::Success, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x20, // Packet type + 0x57, // Remaining length + 0x00, 0x00, // Session, code + 0x54, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x21, 0x01, 0xb0, // Receive maximum + 0x24, 0x02, // Maximum qos + 0x25, 0x01, // Retain available + 0x27, 0x00, 0x00, 0x00, 0x64, // Maximum packet size + 0x12, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Assigned client identifier + 0x22, 0x01, 0xc8, // Topic alias max + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x28, 0x01, // wildcard_subscription_available + 0x29, 0x01, // subscription_identifiers_available + 0x2a, 0x00, // shared_subscription_available + 0x13, 0x04, 0xd2, // server keep_alive + 0x1a, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // response_information + 0x1c, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication data + ] + } + + #[test] + fn connack_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connack = ConnAck::read(fixed_header, connack_bytes).unwrap(); + + assert_eq!(connack, sample()); + } + + #[test] + fn connack_encoding_works() { + let connack = sample(); + let mut buf = BytesMut::new(); + connack.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/connect.rs b/rumqttc/src/v5/packet/connect.rs new file mode 100644 index 000000000..1c630ae79 --- /dev/null +++ b/rumqttc/src/v5/packet/connect.rs @@ -0,0 +1,928 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Connection packet initiated by the client +#[derive(Debug, Clone, PartialEq)] +pub struct Connect { + /// Mqtt protocol version + pub protocol: Protocol, + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, +} + +impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn set_login, P: Into>(&mut self, u: U, p: P) -> &mut Connect { + let login = Login { + username: u.into(), + password: p.into(), + }; + + self.login = Some(login); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_string(&mut bytes)?; + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol = match protocol_level { + 4 => Protocol::V4, + 5 => Protocol::V5, + num => return Err(Error::InvalidProtocolLevel(num)), + }; + + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + // Properties in variable header + let properties = match protocol { + Protocol::V5 => ConnectProperties::read(&mut bytes)?, + Protocol::V4 => None, + }; + + let client_id = read_mqtt_string(&mut bytes)?; + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + protocol, + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + + match self.protocol { + Protocol::V4 => buffer.put_u8(0x04), + Protocol::V5 => buffer.put_u8(0x05), + } + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } +} + +/// LastWill that broker forwards on behalf of the client +#[derive(Debug, Clone, PartialEq)] +pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + pub properties: Option, +} + +impl LastWill { + pub fn new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 0; + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + }; + + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + // Properties in variable header + let properties = WillProperties::read(&mut bytes)?; + + let will_topic = read_mqtt_string(&mut bytes)?; + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + properties, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct WillProperties { + pub delay_interval: Option, + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub content_type: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, +} + +impl WillProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.delay_interval.is_some() { + len += 1 + 4; + } + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut delay_interval = None; + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut content_type = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::WillDelayInterval => { + delay_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(WillProperties { + delay_interval, + payload_format_indicator, + message_expiry_interval, + content_type, + response_topic, + correlation_data, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(delay_interval) = self.delay_interval { + buffer.put_u8(PropertyType::WillDelayInterval as u8); + buffer.put_u32(delay_interval); + } + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Login { + pub username: String, + pub password: String, +} + +impl Login { + pub fn new, P: Into>(u: U, p: P) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, +} + +impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_maximum { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(request_response_info) = self.request_response_info { + buffer.put_u8(PropertyType::RequestResponseInformation as u8); + buffer.put_u8(request_response_info); + } + + if let Some(request_problem_info) = self.request_problem_info { + buffer.put_u8(PropertyType::RequestProblemInformation as u8); + buffer.put_u8(request_problem_info); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn sample() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will_properties = WillProperties { + delay_interval: Some(1234), + payload_format_indicator: Some(0), + message_expiry_interval: Some(4321), + content_type: Some("test".to_owned()), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: Some(will_properties), + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x9d, // remaining len + 0x01, // remaining len + 0x00, 0x04, // 4 + 0x4d, // M + 0x51, // Q + 0x54, // T + 0x54, // T + 0x05, // Level + 0xc6, // connect flags + 0x00, 0x00, // keep alive + 0x2f, // properties len + 0x11, 0x00, 0x00, 0x04, 0xd2, // session expiry interval + 0x21, 0x01, 0xb0, // receive_maximum + 0x27, 0x00, 0x00, 0x00, 0x64, // max packet size + 0x22, 0x01, 0xc8, // topic_alias_max + 0x19, 0x01, // request_response_info + 0x17, 0x01, // request_problem_info + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication_method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication_data + 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, // client id + 0x2f, // will properties len + 0x18, 0x00, 0x00, 0x04, 0xd2, // will delay interval + 0x01, 0x00, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message expiry interval + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content type + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // will user properties + 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, // will topic + 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, // will payload + 0x00, 0x06, 0x6d, 0x61, 0x74, 0x74, 0x65, 0x6f, // username + 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, // password + ] + } + + #[test] + fn connect1_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample()); + } + + #[test] + fn connect1_encoding_works() { + let connect = sample(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: "hackathonmqtt5test".to_owned(), + clean_session: true, + last_will: None, + login: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x1f, 0x00, // remaining len + 0x04, // 4 + 0x4d, 0x51, 0x54, 0x54, // MQTT + 0x05, // level + 0x02, // connect flags + 0x00, 0x0a, // keep alive + 0x00, 0x00, 0x12, 0x68, 0x61, 0x63, 0x6b, 0x61, 0x74, 0x68, 0x6f, 0x6e, 0x6d, 0x71, + 0x74, 0x74, 0x35, 0x74, 0x65, 0x73, 0x74, // payload + 0x10, 0x11, 0x12, // extra bytes in the stream + ] + } + + #[test] + fn connect2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample2()); + } + + #[test] + fn connect2_encoding_works() { + let connect = sample2(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample2_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len() - 3)]); + } + + fn sample3() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: None, + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample3_bytes() -> Vec { + vec![ + 0x10, 0x6e, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0xc6, 0x00, 0x00, 0x2f, 0x11, + 0x00, 0x00, 0x04, 0xd2, 0x21, 0x01, 0xb0, 0x27, 0x00, 0x00, 0x00, 0x64, 0x22, 0x01, + 0xc8, 0x19, 0x01, 0x17, 0x01, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, + 0x74, 0x65, 0x73, 0x74, 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x16, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x00, 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, 0x00, 0x06, 0x6d, + 0x61, 0x74, 0x74, 0x65, 0x6f, 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, + ] + } + + #[test] + fn connect3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample3()); + } + + #[test] + fn connect3_encoding_works() { + let connect = sample3(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample3_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len())]); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/disconnect.rs b/rumqttc/src/v5/packet/disconnect.rs new file mode 100644 index 000000000..508f3c427 --- /dev/null +++ b/rumqttc/src/v5/packet/disconnect.rs @@ -0,0 +1,434 @@ +use std::convert::{TryFrom, TryInto}; + +use bytes::{BufMut, BytesMut, Bytes}; + +use super::*; + +use super::{property, PropertyType}; + +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum DisconnectReasonCode { + /// Close the connection normally. Do not send the Will Message. + NormalDisconnection = 0x00, + /// The Client wishes to disconnect but requires that the Server also publishes its Will Message. + DisconnectWithWillMessage = 0x04, + /// The Connection is closed but the sender either does not wish to reveal the reason, or none of the other Reason Codes apply. + UnspecifiedError = 0x80, + /// The received packet does not conform to this specification. + MalformedPacket = 0x81, + /// An unexpected or out of order packet was received. + ProtocolError = 0x82, + /// The packet received is valid but cannot be processed by this implementation. + ImplementationSpecificError = 0x83, + /// The request is not authorized. + NotAuthorized = 0x87, + /// The Server is busy and cannot continue processing requests from this Client. + ServerBusy = 0x89, + /// The Server is shutting down. + ServerShuttingDown = 0x8B, + /// The Connection is closed because no packet has been received for 1.5 times the Keepalive time. + KeepAliveTimeout = 0x8D, + /// Another Connection using the same ClientID has connected causing this Connection to be closed. + SessionTakenOver = 0x8E, + /// The Topic Filter is correctly formed, but is not accepted by this Sever. + TopicFilterInvalid = 0x8F, + /// The Topic Name is correctly formed, but is not accepted by this Client or Server. + TopicNameInvalid = 0x90, + /// The Client or Server has received more than Receive Maximum publication for which it has not sent PUBACK or PUBCOMP. + ReceiveMaximumExceeded = 0x93, + /// The Client or Server has received a PUBLISH packet containing a Topic Alias which is greater than the Maximum Topic Alias it sent in the CONNECT or CONNACK packet. + TopicAliasInvalid = 0x94, + /// The packet size is greater than Maximum Packet Size for this Client or Server. + PacketTooLarge = 0x95, + /// The received data rate is too high. + MessageRateTooHigh = 0x96, + /// An implementation or administrative imposed limit has been exceeded. + QuotaExceeded = 0x97, + /// The Connection is closed due to an administrative action. + AdministrativeAction = 0x98, + /// The payload format does not match the one specified by the Payload Format Indicator. + PayloadFormatInvalid = 0x99, + /// The Server has does not support retained messages. + RetainNotSupported = 0x9A, + /// The Client specified a QoS greater than the QoS specified in a Maximum QoS in the CONNACK. + QoSNotSupported = 0x9B, + /// The Client should temporarily change its Server. + UseAnotherServer = 0x9C, + /// The Server is moved and the Client should permanently change its server location. + ServerMoved = 0x9D, + /// The Server does not support Shared Subscriptions. + SharedSubscriptionNotSupported = 0x9E, + /// This connection is closed because the connection rate is too high. + ConnectionRateExceeded = 0x9F, + /// The maximum connection time authorized for this connection has been exceeded. + MaximumConnectTime = 0xA0, + /// The Server does not support Subscription Identifiers; the subscription is not accepted. + SubscriptionIdentifiersNotSupported = 0xA1, + /// The Server does not support Wildcard subscription; the subscription is not accepted. + WildcardSubscriptionsNotSupported = 0xA2, +} + +impl TryFrom for DisconnectReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let rc = match value { + 0x00 => Self::NormalDisconnection, + 0x04 => Self::DisconnectWithWillMessage, + 0x80 => Self::UnspecifiedError, + 0x81 => Self::MalformedPacket, + 0x82 => Self::ProtocolError, + 0x83 => Self::ImplementationSpecificError, + 0x87 => Self::NotAuthorized, + 0x89 => Self::ServerBusy, + 0x8B => Self::ServerShuttingDown, + 0x8D => Self::KeepAliveTimeout, + 0x8E => Self::SessionTakenOver, + 0x8F => Self::TopicFilterInvalid, + 0x90 => Self::TopicNameInvalid, + 0x93 => Self::ReceiveMaximumExceeded, + 0x94 => Self::TopicAliasInvalid, + 0x95 => Self::PacketTooLarge, + 0x96 => Self::MessageRateTooHigh, + 0x97 => Self::QuotaExceeded, + 0x98 => Self::AdministrativeAction, + 0x99 => Self::PayloadFormatInvalid, + 0x9A => Self::RetainNotSupported, + 0x9B => Self::QoSNotSupported, + 0x9C => Self::UseAnotherServer, + 0x9D => Self::ServerMoved, + 0x9E => Self::SharedSubscriptionNotSupported, + 0x9F => Self::ConnectionRateExceeded, + 0xA0 => Self::MaximumConnectTime, + 0xA1 => Self::SubscriptionIdentifiersNotSupported, + 0xA2 => Self::WildcardSubscriptionsNotSupported, + other => return Err(Error::InvalidConnectReturnCode(other)), + }; + + Ok(rc) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DisconnectProperties { + /// Session Expiry Interval in seconds + pub session_expiry_interval: Option, + + /// Human readable reason for the disconnect + pub reason_string: Option, + + /// List of user properties + pub user_properties: Vec<(String, String)>, + + /// String which can be used by the Client to identify another Server to use. + pub server_reference: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Disconnect { + /// Disconnect Reason Code + pub reason_code: DisconnectReasonCode, + + /// Disconnect Properties + pub properties: Option, +} + +impl DisconnectProperties { + pub fn new() -> Self { + Self { + session_expiry_interval: None, + reason_string: None, + user_properties: Vec::new(), + server_reference: None, + } + } + + fn len(&self) -> usize { + let mut length = 0; + + if self.session_expiry_interval.is_some() { + length += 1 + 4; + } + + if let Some(reason) = &self.reason_string { + length += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + length += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(server_reference) = &self.server_reference { + length += 1 + 2 + server_reference.len(); + } + + length + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let (properties_len_len, properties_len) = length(bytes.iter())?; + + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut session_expiry_interval = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut server_reference = None; + + let mut cursor = 0; + + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + let properties = Self { + session_expiry_interval, + reason_string, + user_properties, + server_reference, + }; + + Ok(Some(properties)) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let length = self.len(); + write_remaining_length(buffer, length)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + Ok(()) + } +} + +impl Disconnect { + pub fn new() -> Self { + Self { + reason_code: DisconnectReasonCode::NormalDisconnection, + properties: None, + } + } + + fn len(&self) -> usize { + if self.reason_code == DisconnectReasonCode::NormalDisconnection + && self.properties.is_none() + { + return 2; // Packet type + 0x00 + } + + let mut length = 0; + + match &self.properties { + Some(properties) => { + length += 1; // Disconnect Reason Code + + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + length += properties_len_len + properties_len; + } + None if self.reason_code == DisconnectReasonCode::NormalDisconnection => {} + None => { + length += 1; // Disconnect Reason Code + } + }; + + length + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let packet_type = fixed_header.byte1 >> 4; + let flags = fixed_header.byte1 & 0b0000_1111; + + bytes.advance(fixed_header.fixed_header_len); + + if packet_type != PacketType::Disconnect as u8 { + return Err(Error::InvalidPacketType(packet_type)); + }; + + if flags != 0x00 { + return Err(Error::MalformedPacket); + }; + + if fixed_header.remaining_len == 0 { + return Ok(Self::new()); + } + + let reason_code = read_u8(&mut bytes)?; + + let disconnect = Self { + reason_code: reason_code.try_into()?, + properties: DisconnectProperties::extract(&mut bytes)?, + }; + + Ok(disconnect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xE0); + + let length = self.len(); + + if length == 2 { + buffer.put_u8(0x00); + + return Ok(length); + } + + let len_len = write_remaining_length(buffer, length)?; + + buffer.put_u8(self.reason_code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + len_len + length) + } +} + +#[cfg(test)] +mod test { + use bytes::BytesMut; + + use super::parse_fixed_header; + + use super::{Disconnect, DisconnectProperties, DisconnectReasonCode}; + + #[test] + fn disconnect1_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + let expected = Disconnect::new(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect1_encoding_works() { + let mut buffer = BytesMut::new(); + let disconnect = Disconnect::new(); + let expected = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } + + fn sample2() -> Disconnect { + let properties = DisconnectProperties { + // TODO: change to 2137 xD + session_expiry_interval: Some(1234), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + server_reference: Some("test".to_owned()), + }; + + Disconnect { + reason_code: DisconnectReasonCode::UnspecifiedError, + properties: Some(properties), + } + } + + fn sample_bytes2() -> Vec { + vec![ + 0xE0, // Packet type + 0x22, // Remaining length + 0x80, // Disconnect Reason Code + 0x20, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x1F, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // User properties + 0x1C, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + ] + } + + #[test] + fn disconnect2_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = sample_bytes2(); + let expected = sample2(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect2_encoding_works() { + let mut buffer = BytesMut::new(); + + let disconnect = sample2(); + let expected = sample_bytes2(); + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } +} diff --git a/rumqttc/src/v5/packet/mod.rs b/rumqttc/src/v5/packet/mod.rs new file mode 100644 index 000000000..8f954c1cf --- /dev/null +++ b/rumqttc/src/v5/packet/mod.rs @@ -0,0 +1,489 @@ +use std::{ + fmt::{self, Display, Formatter}, + slice::Iter, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +mod connack; +mod connect; +mod disconnect; +mod ping; +mod puback; +mod pubcomp; +mod publish; +mod pubrec; +mod pubrel; +mod suback; +mod subscribe; +mod unsuback; +mod unsubscribe; + +pub use connack::*; +pub use connect::*; +pub use disconnect::*; +pub use ping::*; +pub use puback::*; +pub use pubcomp::*; +pub use publish::*; +pub use pubrec::*; +pub use pubrel::*; +pub use suback::*; +pub use subscribe::*; +pub use unsuback::*; +pub use unsubscribe::*; + +/// Encapsulates all MQTT packet types +#[derive(Debug, Clone, PartialEq)] +pub enum Packet { + Connect(Connect), + ConnAck(ConnAck), + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubRel(PubRel), + PubComp(PubComp), + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + PingReq, + PingResp, + Disconnect(Disconnect), +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Error { + NotConnect(PacketType), + UnexpectedConnect, + InvalidConnectReturnCode(u8), + InvalidReason(u8), + InvalidProtocol, + InvalidProtocolLevel(u8), + IncorrectPacketFormat, + InvalidPacketType(u8), + InvalidPropertyType(u8), + InvalidRetainForwardRule(u8), + InvalidQoS(u8), + InvalidSubscribeReasonCode(u8), + PacketIdZero, + SubscriptionIdZero, + PayloadSizeIncorrect, + PayloadTooLong, + PayloadSizeLimitExceeded(usize), + PayloadRequired, + TopicNotUtf8, + BoundaryCrossed(usize), + MalformedPacket, + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + InsufficientBytes(usize), +} + +/// Protocol type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Protocol { + V4, + V5, +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::read(fixed_header, packet)?), + }; + + Ok(packet) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Reads a string from bytes stream +fn read_mqtt_string(stream: &mut Bytes) -> Result { + let s = read_mqtt_bytes(stream)?; + match String::from_utf8(s.to_vec()) { + Ok(v) => Ok(v), + Err(_e) => Err(Error::TopicNotUtf8), + } +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes +fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.is_empty() { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Error = {:?}", self) + } +} diff --git a/rumqttc/src/v5/packet/ping.rs b/rumqttc/src/v5/packet/ping.rs new file mode 100644 index 000000000..1072029bd --- /dev/null +++ b/rumqttc/src/v5/packet/ping.rs @@ -0,0 +1,20 @@ +use super::*; +use bytes::{BufMut, BytesMut}; + +pub struct PingReq; + +impl PingReq { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +pub struct PingResp; + +impl PingResp { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} diff --git a/rumqttc/src/v5/packet/puback.rs b/rumqttc/src/v5/packet/puback.rs new file mode 100644 index 000000000..51131949e --- /dev/null +++ b/rumqttc/src/v5/packet/puback.rs @@ -0,0 +1,324 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, +} + +impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties, sending reason code is optional + if self.reason == PubAckReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // Reason code is optional with success if there are no properties + if self.reason == PubAckReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(v5)] +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubAck { + let properties = PubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x40, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn puback_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample()); + } + + #[test] + fn puback_encoding_works() { + let puback = sample(); + let mut buf = BytesMut::new(); + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![0x40, 0x03, 0x00, 0x2a, 0x10] + } + + #[test] + fn puback2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample2()); + } + + #[test] + fn puback2_encoding_works() { + let puback = sample2(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample2_bytes()); + } + + fn sample3() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::Success, + properties: None, + } + } + + fn sample3_bytes() -> Vec { + vec![0x40, 0x02, 0x00, 0x2a] + } + + #[test] + fn puback3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample3()); + } + + #[test] + fn puback3_encoding_works() { + let puback = sample3(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample3_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubcomp.rs b/rumqttc/src/v5/packet/pubcomp.rs new file mode 100644 index 000000000..badb97867 --- /dev/null +++ b/rumqttc/src/v5/packet/pubcomp.rs @@ -0,0 +1,237 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubCompReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubComp { + pub pkid: u16, + pub reason: PubCompReason, + pub properties: Option, +} + +impl PubComp { + pub fn new(pkid: u16) -> PubComp { + PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if fixed_header.remaining_len == 2 { + return Ok(PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubComp { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubComp { + pkid, + reason: reason(ack_reason)?, + properties: PubCompProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x70); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubCompProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubCompProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubCompProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubCompReason::Success, + 146 => PubCompReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubComp { + let properties = PubCompProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubComp { + pkid: 42, + reason: PubCompReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x70, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubcomp_parsing_works_correctly() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubcomp_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubcomp = PubComp::read(fixed_header, pubcomp_bytes).unwrap(); + assert_eq!(pubcomp, sample()); + } + + #[test] + fn pubcomp_encoding_works_correctly() { + let pubcomp = sample(); + let mut buf = BytesMut::new(); + pubcomp.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/publish.rs b/rumqttc/src/v5/packet/publish.rs new file mode 100644 index 000000000..9f0e228ea --- /dev/null +++ b/rumqttc/src/v5/packet/publish.rs @@ -0,0 +1,394 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Publish packet +#[derive(Clone, PartialEq)] +pub struct Publish { + pub dup: bool, + pub qos: QoS, + pub retain: bool, + pub topic: String, + pub pkid: u16, + pub properties: Option, + pub payload: Bytes, +} + +impl Publish { + pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload: Bytes::from(payload.into()), + } + } + + pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.topic.len(); + if self.qos != QoS::AtMostOnce && self.pkid != 0 { + len += 2; + } + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += self.payload.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let qos = qos((fixed_header.byte1 & 0b0110) >> 1)?; + let dup = (fixed_header.byte1 & 0b1000) != 0; + let retain = (fixed_header.byte1 & 0b0001) != 0; + + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let topic = read_mqtt_string(&mut bytes)?; + + // Packet identifier exists where QoS > 0 + let pkid = match qos { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?, + }; + + if qos != QoS::AtMostOnce && pkid == 0 { + return Err(Error::PacketIdZero); + } + + let publish = Publish { + dup, + retain, + qos, + pkid, + topic, + properties: PublishProperties::extract(&mut bytes)?, + payload: bytes, + }; + + Ok(publish) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + + let dup = self.dup as u8; + let qos = self.qos as u8; + let retain = self.retain as u8; + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, self.topic.as_str()); + + if self.qos != QoS::AtMostOnce { + let pkid = self.pkid; + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + buffer.extend_from_slice(&self.payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PublishProperties { + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub topic_alias: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, + pub subscription_identifiers: Vec, + pub content_type: Option, +} + +impl PublishProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.topic_alias.is_some() { + len += 1 + 2; + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + for id in self.subscription_identifiers.iter() { + len += 1 + len_len(*id); + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut topic_alias = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + let mut subscription_identifiers = Vec::new(); + let mut content_type = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAlias => { + topic_alias = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::SubscriptionIdentifier => { + let (id_len, id) = length(bytes.iter())?; + cursor += 1 + id_len; + bytes.advance(id_len); + subscription_identifiers.push(id); + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PublishProperties { + payload_format_indicator, + message_expiry_interval, + topic_alias, + response_topic, + correlation_data, + user_properties, + subscription_identifiers, + content_type, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(topic_alias) = self.topic_alias { + buffer.put_u8(PropertyType::TopicAlias as u8); + buffer.put_u16(topic_alias); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + for id in self.subscription_identifiers.iter() { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + Ok(()) + } +} + +impl fmt::Debug for Publish { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Topic = {}, Qos = {:?}, Retain = {}, Pkid = {:?}, Payload Size = {}", + self.topic, + self.qos, + self.retain, + self.pkid, + self.payload.len() + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample_v5() -> Publish { + let publish_properties = PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: Some(4321), + topic_alias: Some(100), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + subscription_identifiers: vec![120, 121], + content_type: Some("test".to_owned()), + }; + + Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: false, + topic: "test".to_string(), + pkid: 42, + properties: Some(publish_properties), + payload: Bytes::from(vec![b't', b'e', b's', b't']), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x34, // payload type + 0x3e, // remaining len + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // topic name + 0x00, 0x2a, // pkid + 0x31, // properties len + 0x01, 0x01, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message_expiry_interval + 0x23, 0x00, 0x64, // topic alias + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x0b, 0x78, // subscription_identifier + 0x0b, 0x79, // subscription_identifier + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content_type + 0x74, 0x65, 0x73, 0x74, // payload + ] + } + + #[test] + fn publish_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let publish_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let publish = Publish::read(fixed_header, publish_bytes).unwrap(); + assert_eq!(publish, sample_v5()); + } + + #[test] + fn publish_encoding_works() { + let publish = sample_v5(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/pubrec.rs b/rumqttc/src/v5/packet/pubrec.rs new file mode 100644 index 000000000..5e8de572e --- /dev/null +++ b/rumqttc/src/v5/packet/pubrec.rs @@ -0,0 +1,252 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRecReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRec { + pub pkid: u16, + pub reason: PubRecReason, + pub properties: Option, +} + +impl PubRec { + pub fn new(pkid: u16) -> PubRec { + PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRec { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRec { + pkid, + reason: reason(ack_reason)?, + properties: PubRecProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x50); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRecProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRecProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRecProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRecReason::Success, + 16 => PubRecReason::NoMatchingSubscribers, + 128 => PubRecReason::UnspecifiedError, + 131 => PubRecReason::ImplementationSpecificError, + 135 => PubRecReason::NotAuthorized, + 144 => PubRecReason::TopicNameInvalid, + 145 => PubRecReason::PacketIdentifierInUse, + 151 => PubRecReason::QuotaExceeded, + 153 => PubRecReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRec { + let properties = PubRecProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRec { + pkid: 42, + reason: PubRecReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x50, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrec_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrec_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrec = PubRec::read(fixed_header, pubrec_bytes).unwrap(); + assert_eq!(pubrec, sample()); + } + + #[test] + fn pubrec_encoding_works() { + let pubrec = sample(); + let mut buf = BytesMut::new(); + pubrec.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubrel.rs b/rumqttc/src/v5/packet/pubrel.rs new file mode 100644 index 000000000..1a1a62e4d --- /dev/null +++ b/rumqttc/src/v5/packet/pubrel.rs @@ -0,0 +1,236 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRelReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRel { + pub pkid: u16, + pub reason: PubRelReason, + pub properties: Option, +} + +impl PubRel { + pub fn new(pkid: u16) -> PubRel { + PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRel { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRel { + pkid, + reason: reason(ack_reason)?, + properties: PubRelProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x62); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRelProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRelProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRelProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRelReason::Success, + 146 => PubRelReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRel { + let properties = PubRelProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRel { + pkid: 42, + reason: PubRelReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x62, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrel_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrel_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrel = PubRel::read(fixed_header, pubrel_bytes).unwrap(); + assert_eq!(pubrel, sample()); + } + + #[test] + fn pubrel_encoding_works() { + let pubrel = sample(); + let mut buf = BytesMut::new(); + pubrel.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/suback.rs b/rumqttc/src/v5/packet/suback.rs new file mode 100644 index 000000000..0ec0ead05 --- /dev/null +++ b/rumqttc/src/v5/packet/suback.rs @@ -0,0 +1,263 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::convert::{TryFrom, TryInto}; + +/// Acknowledgement to subscribe +#[derive(Debug, Clone, PartialEq)] +pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, +} + +impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0x90); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, +} + +impl TryFrom for SubscribeReasonCode { + type Error = super::Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(super::Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> SubAck { + let properties = SubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + SubAck { + pkid: 42, + return_codes: vec![ + SubscribeReasonCode::QoS0, + SubscribeReasonCode::QoS1, + SubscribeReasonCode::QoS2, + SubscribeReasonCode::Unspecified, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x90, // packet type + 0x1b, // remaining len + 0x00, 0x2a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // user properties + 0x00, 0x01, 0x02, 0x80, // return codes + ] + } + + #[test] + fn suback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let suback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let suback = SubAck::read(fixed_header, suback_bytes).unwrap(); + assert_eq!(suback, sample()); + } + + #[test] + fn suback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/subscribe.rs b/rumqttc/src/v5/packet/subscribe.rs new file mode 100644 index 000000000..9a5c43d1e --- /dev/null +++ b/rumqttc/src/v5/packet/subscribe.rs @@ -0,0 +1,425 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Subscription packet +#[derive(Clone, PartialEq)] +pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + Subscribe { + pkid: 0, + filters: vec![filter], + properties: None, + } + } + + pub fn new_many(topics: T) -> Subscribe + where + T: IntoIterator, + { + Subscribe { + pkid: 0, + filters: topics.into_iter().collect(), + properties: None, + } + } + + pub fn empty_subscribe() -> Subscribe { + Subscribe { + pkid: 0, + filters: Vec::new(), + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_string(&mut bytes)?; + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = nolocal != 0; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = preserve_retain != 0; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Subscription filter +#[derive(Clone, PartialEq)] +pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, +} + +impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn set_nolocal(&mut self, flag: bool) -> &mut Self { + self.nolocal = flag; + self + } + + pub fn set_preserve_retain(&mut self, flag: bool) -> &mut Self { + self.preserve_retain = flag; + self + } + + pub fn set_retain_forward_rule(&mut self, rule: RetainForwardRule) -> &mut Self { + self.retain_forward_rule = rule; + self + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, +} + +impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } +} + +impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Subscribe { + let subscribe_properties = SubscribeProperties { + id: Some(100), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let mut filter = SubscribeFilter::new("hello".to_owned(), QoS::AtLeastOnce); + filter + .set_nolocal(true) + .set_preserve_retain(true) + .set_retain_forward_rule(RetainForwardRule::Never); + + Subscribe { + pkid: 42, + filters: vec![filter], + properties: Some(subscribe_properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x82, // packet type + 0x1a, // remaining length + 0x00, 0x2a, // pkid + 0x0f, // properties len + 0x0b, 0x64, // subscription identifier + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter + 0x2d, // options + ] + } + + #[test] + fn subscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Subscribe { + let filter = SubscribeFilter::new("hello/world".to_owned(), QoS::AtLeastOnce); + Subscribe { + pkid: 42, + filters: vec![filter], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x82, 0x11, 0x00, 0x2a, 0x00, 0x00, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2f, 0x77, + 0x6f, 0x72, 0x6c, 0x64, 0x01, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsuback.rs b/rumqttc/src/v5/packet/unsuback.rs new file mode 100644 index 000000000..ce01bb4cd --- /dev/null +++ b/rumqttc/src/v5/packet/unsuback.rs @@ -0,0 +1,249 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +//// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum UnsubAckReason { + Success = 0x00, + NoSubscriptionExisted = 0x11, + UnspecifiedError = 0x80, + ImplementationSpecificError = 0x83, + NotAuthorized = 0x87, + TopicFilterInvalid = 0x8F, + PacketIdentifierInUse = 0x91, +} + +/// Acknowledgement to unsubscribe +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAck { + pub pkid: u16, + pub reasons: Vec, + pub properties: Option, +} + +impl UnsubAck { + pub fn new(pkid: u16) -> UnsubAck { + UnsubAck { + pkid, + reasons: Vec::new(), + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.reasons.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = UnsubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut reasons = Vec::new(); + while bytes.has_remaining() { + let r = read_u8(&mut bytes)?; + reasons.push(reason(r)?); + } + + let unsuback = UnsubAck { + pkid, + reasons, + properties, + }; + + Ok(unsuback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xB0); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.reasons.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl UnsubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0x00 => UnsubAckReason::Success, + 0x11 => UnsubAckReason::NoSubscriptionExisted, + 0x80 => UnsubAckReason::UnspecifiedError, + 0x83 => UnsubAckReason::ImplementationSpecificError, + 0x87 => UnsubAckReason::NotAuthorized, + 0x8F => UnsubAckReason::TopicFilterInvalid, + 0x91 => UnsubAckReason::PacketIdentifierInUse, + num => return Err(Error::InvalidSubscribeReasonCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> UnsubAck { + let properties = UnsubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + UnsubAck { + pkid: 10, + reasons: vec![ + UnsubAckReason::NotAuthorized, + UnsubAckReason::TopicFilterInvalid, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xb0, // packet type + 0x19, // remaining len + 0x00, 0x0a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x87, 0x8f, // reasons + ] + } + + #[test] + fn unsuback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsuback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsuback = UnsubAck::read(fixed_header, unsuback_bytes).unwrap(); + assert_eq!(unsuback, sample()); + } + + #[test] + fn unsuback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsubscribe.rs b/rumqttc/src/v5/packet/unsubscribe.rs new file mode 100644 index 000000000..8700c5132 --- /dev/null +++ b/rumqttc/src/v5/packet/unsubscribe.rs @@ -0,0 +1,238 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Unsubscribe packet +#[derive(Debug, Clone, PartialEq)] +pub struct Unsubscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Unsubscribe { + pub fn new>(topic: S) -> Unsubscribe { + Unsubscribe { + pkid: 0, + filters: vec![topic.into()], + properties: None, + } + } + + pub fn len(&self) -> usize { + // Packet id + length of filters (unlike subscribe, this just a string. + // Hence 2 is prefixed for len per filter) + let mut len = 2 + self.filters.iter().fold(0, |s, t| 2 + s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + dbg!(pkid); + let properties = UnsubscribeProperties::extract(&mut bytes)?; + + let mut filters = Vec::with_capacity(1); + while bytes.has_remaining() { + let filter = read_mqtt_string(&mut bytes)?; + filters.push(filter); + } + + let unsubscribe = Unsubscribe { + pkid, + filters, + properties, + }; + Ok(unsubscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xA2); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + write_mqtt_string(buffer, filter); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubscribeProperties { + pub user_properties: Vec<(String, String)>, +} + +impl UnsubscribeProperties { + fn len(&self) -> usize { + let mut len = 0; + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubscribeProperties { user_properties })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Unsubscribe { + let properties = UnsubscribeProperties { + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned(), "world".to_owned()], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xa2, // packet type + 0x1e, // remaining len + 0x00, 0x0a, // pkid + 0x0d, // properties len + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter 1 + 0x00, 0x05, 0x77, 0x6f, 0x72, 0x6c, 0x64, // filter 2 + ] + } + + #[test] + fn unsubscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsubscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsubscribe = Unsubscribe::read(fixed_header, unsubscribe_bytes).unwrap(); + assert_eq!(unsubscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + println!("{:X?}", buf); + println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Unsubscribe { + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned()], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0xa2, 0x0a, 0x00, 0x0a, 0x00, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Unsubscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} From bff5bbd24c800b067d885ae45b0243d736865fb0 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:26:02 +0530 Subject: [PATCH 04/41] rumqttc: v5: use mqttbytes/v5 Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 4 ++-- rumqttc/src/v5/eventloop.rs | 8 ++++---- rumqttc/src/v5/framed.rs | 4 ++-- rumqttc/src/v5/mod.rs | 3 +-- rumqttc/src/v5/state.rs | 12 ++++++------ rumqttc/src/v5/tls.rs | 2 +- 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 8b95c1bc2..53c0eaad5 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -1,10 +1,10 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::v4::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::v5::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use mqttbytes::*; use std::mem; use tokio::runtime; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index d4ce78c7a..8a07317a1 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,11 +1,11 @@ -use crate::v4::{framed::Network, Transport}; -use crate::v4::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::v4::{MqttOptions, Outgoing}; +use crate::v5::{framed::Network, Transport}; +use crate::v5::{tls, Incoming, MqttState, Packet, Request, StateError}; +use crate::v5::{MqttOptions, Outgoing}; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 96bda17da..03855b062 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,9 +1,9 @@ use bytes::BytesMut; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use mqttbytes::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::v4::{Incoming, MqttState, StateError}; +use crate::v5::{Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 01800c685..012f7e0a6 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -12,8 +12,7 @@ mod tls; pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; -pub use mqttbytes::v4::*; -pub use mqttbytes::*; +pub use packet::*; pub use state::{MqttState, StateError}; pub use tokio_rustls::rustls::ClientConfig; pub use tls::Error; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index c51116ccf..e8f72592f 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,7 +1,7 @@ -use crate::v4::{Event, Incoming, Outgoing, Request}; +use super::{Event, Incoming, Outgoing, Request}; use bytes::BytesMut; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use mqttbytes::*; use std::collections::VecDeque; use std::{io, mem, time::Instant}; @@ -422,7 +422,7 @@ impl MqttState { debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", - unsub.topics, unsub.pkid + unsub.filters, unsub.pkid ); unsub.write(&mut self.write)?; @@ -434,7 +434,7 @@ impl MqttState { fn outgoing_disconnect(&mut self) -> Result<(), StateError> { debug!("Disconnect"); - Disconnect.write(&mut self.write)?; + Disconnect::new().write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); Ok(()) @@ -487,8 +487,8 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::v4::{Event, Incoming, MqttOptions, Outgoing, Request}; - use mqttbytes::v4::*; + use crate::v5::{Event, Incoming, MqttOptions, Outgoing, Request}; + use mqttbytes::v5::*; use mqttbytes::*; fn build_outgoing_publish(qos: QoS) -> Publish { diff --git a/rumqttc/src/v5/tls.rs b/rumqttc/src/v5/tls.rs index ea1809e23..3d07819b2 100644 --- a/rumqttc/src/v5/tls.rs +++ b/rumqttc/src/v5/tls.rs @@ -7,7 +7,7 @@ use tokio_rustls::rustls::{ use tokio_rustls::webpki; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::v4::{Key, MqttOptions, TlsConfiguration}; +use crate::v5::{Key, MqttOptions, TlsConfiguration}; use std::convert::TryFrom; use std::io; From 093d140aa6995a71c7d7dca496fa4054e72c5b9e Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:29:58 +0530 Subject: [PATCH 05/41] rumqttc: v5: use v5::packet instead of mqttbytes::v5 Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 31 ++++++------------------------- rumqttc/src/v5/eventloop.rs | 8 ++++---- rumqttc/src/v5/framed.rs | 6 ++---- rumqttc/src/v5/state.rs | 10 ++++------ 4 files changed, 16 insertions(+), 39 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 53c0eaad5..846ae9901 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -1,11 +1,9 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::v5::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::v5::{packet::*, ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; -use mqttbytes::v5::*; -use mqttbytes::*; use std::mem; use tokio::runtime; use tokio::runtime::Runtime; @@ -94,11 +92,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { @@ -108,11 +102,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { self.request_tx.try_send(ack)?; @@ -217,7 +207,7 @@ fn get_ack_req(publish: &Publish) -> Option { let ack = match publish.qos { QoS::AtMostOnce => return None, QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)) + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), }; Some(ack) } @@ -277,26 +267,17 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { pollster::block_on(self.client.ack(publish))?; Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish)?; Ok(()) } - /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { pollster::block_on(self.client.subscribe(topic, qos))?; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 8a07317a1..fe8b1cb6b 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,11 +1,11 @@ -use crate::v5::{framed::Network, Transport}; -use crate::v5::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::v5::{MqttOptions, Outgoing}; +use crate::v5::{ + framed::Network, packet::*, tls, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, + StateError, Transport, +}; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; -use mqttbytes::v5::*; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 03855b062..684694d5a 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,9 +1,7 @@ use bytes::BytesMut; -use mqttbytes::v5::*; -use mqttbytes::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::v5::{Incoming, MqttState, StateError}; +use crate::v5::{packet::*, Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes @@ -61,7 +59,7 @@ impl Network { loop { let required = match read(&mut self.read, self.max_incoming_size) { Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, + Err(Error::InsufficientBytes(required)) => required, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), }; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index e8f72592f..491eabc62 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,8 +1,6 @@ -use super::{Event, Incoming, Outgoing, Request}; +use super::{Event, Incoming, Outgoing, Request, packet::*}; use bytes::BytesMut; -use mqttbytes::v5::*; -use mqttbytes::*; use std::collections::VecDeque; use std::{io, mem, time::Instant}; @@ -30,11 +28,11 @@ pub enum StateError { #[error("Timeout while waiting to resolve collision")] CollisionTimeout, #[error("Mqtt serialization/deserialization error")] - Deserialization(mqttbytes::Error), + Deserialization(Error), } -impl From for StateError { - fn from(e: mqttbytes::Error) -> StateError { +impl From for StateError { + fn from(e: Error) -> StateError { StateError::Deserialization(e) } } From cbc25b211515bfb5dd8600c714fb2b5a42512313 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:44:09 +0530 Subject: [PATCH 06/41] rumqttc: fix examples, tests and docs Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 1 - rumqttc/examples/syncpubsub.rs | 4 +++- rumqttc/examples/tls2.rs | 2 +- rumqttc/src/lib.rs | 4 ++-- rumqttc/src/v4/client.rs | 27 +++++---------------------- rumqttc/src/v4/mod.rs | 14 +++++++++----- rumqttc/src/v4/state.rs | 2 +- rumqttc/src/v5/mod.rs | 14 +++++++++----- rumqttc/src/v5/packet/disconnect.rs | 2 +- rumqttc/src/v5/state.rs | 8 +++----- 10 files changed, 34 insertions(+), 44 deletions(-) diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 6790991e3..127e235db 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -14,7 +14,6 @@ fn create_conn() -> (AsyncClient, EventLoop) { AsyncClient::new(mqttoptions, 10) } - #[tokio::main(worker_threads = 1)] async fn main() -> Result<(), Box> { pretty_env_logger::init(); diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index b13abc65f..a14bcda4e 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -7,7 +7,9 @@ fn main() { let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false); - mqttoptions.set_keep_alive(Duration::from_secs(5)).set_last_will(will); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_last_will(will); let (client, mut connection) = Client::new(mqttoptions, 10); thread::spawn(move || publish(client)); diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index 1bf5cf4fb..cd81364a4 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -14,7 +14,7 @@ async fn main() -> Result<(), Box> { // Dummies to prevent compilation error in CI let ca = vec![1, 2, 3]; let client_cert = vec![1, 2, 3]; - let client_key= vec![1, 2, 3]; + let client_key = vec![1, 2, 3]; // let ca = include_bytes!("/home/tekjar/tlsfiles/ca.cert.pem"); // let client_cert = include_bytes!("/home/tekjar/tlsfiles/device-1.cert.pem"); // let client_key = include_bytes!("/home/tekjar/tlsfiles/device-1.key.pem"); diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 0a30cbefd..353397a1e 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -9,7 +9,7 @@ //! ---------------------------- //! //! ```no_run -//! use rumqttc::{MqttOptions, Client, QoS}; +//! use rumqttc::v4::{MqttOptions, Client, QoS}; //! use std::time::Duration; //! use std::thread; //! @@ -33,7 +33,7 @@ //! ------------------------------ //! //! ```no_run -//! use rumqttc::{MqttOptions, AsyncClient, QoS}; +//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; //! use tokio::{task, time}; //! use std::time::Duration; //! use std::error::Error; diff --git a/rumqttc/src/v4/client.rs b/rumqttc/src/v4/client.rs index 8b95c1bc2..623281075 100644 --- a/rumqttc/src/v4/client.rs +++ b/rumqttc/src/v4/client.rs @@ -94,11 +94,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { @@ -108,11 +104,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { self.request_tx.try_send(ack)?; @@ -217,7 +209,7 @@ fn get_ack_req(publish: &Publish) -> Option { let ack = match publish.qos { QoS::AtMostOnce => return None, QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)) + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), }; Some(ack) } @@ -277,26 +269,17 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { pollster::block_on(self.client.ack(publish))?; Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish)?; Ok(()) } - /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { pollster::block_on(self.client.subscribe(topic, qos))?; diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs index 4797b0f2a..882dfdd06 100644 --- a/rumqttc/src/v4/mod.rs +++ b/rumqttc/src/v4/mod.rs @@ -14,8 +14,8 @@ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; pub use state::{MqttState, StateError}; -pub use tokio_rustls::rustls::ClientConfig; pub use tls::Error; +pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -336,7 +336,11 @@ impl MqttOptions { } /// Username and password - pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { self.credentials = Some((username.into(), password.into())); self } @@ -570,7 +574,7 @@ impl std::convert::TryFrom for MqttOptions { inflight, last_will: None, conn_timeout, - manual_acks: false + manual_acks: false, }) } } @@ -613,9 +617,9 @@ mod test { fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + _mqtt_opts.set_transport(crate::v4::Transport::wss(Vec::from("Test CA"), None, None)); - if let crate::Transport::Wss(TlsConfiguration::Simple { + if let crate::v4::Transport::Wss(TlsConfiguration::Simple { ca, client_auth, alpn, diff --git a/rumqttc/src/v4/state.rs b/rumqttc/src/v4/state.rs index c51116ccf..0653d4e98 100644 --- a/rumqttc/src/v4/state.rs +++ b/rumqttc/src/v4/state.rs @@ -100,7 +100,7 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), - manual_acks + manual_acks, } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 012f7e0a6..d4b276151 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -14,8 +14,8 @@ pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; pub use packet::*; pub use state::{MqttState, StateError}; -pub use tokio_rustls::rustls::ClientConfig; pub use tls::Error; +pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -336,7 +336,11 @@ impl MqttOptions { } /// Username and password - pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { self.credentials = Some((username.into(), password.into())); self } @@ -570,7 +574,7 @@ impl std::convert::TryFrom for MqttOptions { inflight, last_will: None, conn_timeout, - manual_acks: false + manual_acks: false, }) } } @@ -613,9 +617,9 @@ mod test { fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + _mqtt_opts.set_transport(crate::v5::Transport::wss(Vec::from("Test CA"), None, None)); - if let crate::Transport::Wss(TlsConfiguration::Simple { + if let crate::v5::Transport::Wss(TlsConfiguration::Simple { ca, client_auth, alpn, diff --git a/rumqttc/src/v5/packet/disconnect.rs b/rumqttc/src/v5/packet/disconnect.rs index 508f3c427..3662087f1 100644 --- a/rumqttc/src/v5/packet/disconnect.rs +++ b/rumqttc/src/v5/packet/disconnect.rs @@ -1,6 +1,6 @@ use std::convert::{TryFrom, TryInto}; -use bytes::{BufMut, BytesMut, Bytes}; +use bytes::{BufMut, Bytes, BytesMut}; use super::*; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 491eabc62..2bc2b470a 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,4 +1,4 @@ -use super::{Event, Incoming, Outgoing, Request, packet::*}; +use super::{packet::*, Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use std::collections::VecDeque; @@ -98,7 +98,7 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), - manual_acks + manual_acks, } } @@ -485,9 +485,7 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::v5::{Event, Incoming, MqttOptions, Outgoing, Request}; - use mqttbytes::v5::*; - use mqttbytes::*; + use crate::v5::{packet::*, Event, Incoming, MqttOptions, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); From ec46060b4d3efb153d640486c7317d59c5a0167b Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sat, 12 Feb 2022 14:58:44 +0530 Subject: [PATCH 07/41] rumqttc: use request_buf in eventloop and client Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 128 +++++++++++++++++------------ rumqttc/src/v5/eventloop.rs | 159 +++++++++++++++++++++--------------- 2 files changed, 169 insertions(+), 118 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 846ae9901..20ad59f95 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -4,19 +4,22 @@ use crate::v5::{packet::*, ConnectionError, Event, EventLoop, MqttOptions, Reque use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; -use std::mem; -use tokio::runtime; -use tokio::runtime::Runtime; +use std::{ + collections::VecDeque, + mem, + sync::{Arc, Mutex}, +}; +use tokio::runtime::{self, Runtime}; /// Client Error #[derive(Debug, thiserror::Error)] pub enum ClientError { #[error("Failed to send cancel request to eventloop")] - Cancel(#[from] SendError<()>), + Cancel(SendError<()>), #[error("Failed to send mqtt requests to eventloop")] - Request(#[from] SendError), + Request(#[from] SendError<()>), #[error("Failed to send mqtt requests to eventloop")] - TryRequest(#[from] TrySendError), + TryRequest(#[from] TrySendError<()>), #[error("Serialization error")] Mqtt4(mqttbytes::Error), } @@ -25,7 +28,8 @@ pub enum ClientError { /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_buf: Arc>>, + request_tx: Sender<()>, cancel_tx: Sender<()>, } @@ -33,10 +37,12 @@ impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let mut eventloop = EventLoop::new(options, cap); + let request_buf = eventloop.buf().clone(); let request_tx = eventloop.handle(); let cancel_tx = eventloop.cancel_handle(); let client = AsyncClient { + request_buf, request_tx, cancel_tx, }; @@ -46,8 +52,13 @@ impl AsyncClient { /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for /// creating a test instance. - pub fn from_senders(request_tx: Sender, cancel_tx: Sender<()>) -> AsyncClient { + pub fn from_senders( + request_buf: Arc>>, + request_tx: Sender<()>, + cancel_tx: Sender<()>, + ) -> AsyncClient { AsyncClient { + request_buf, request_tx, cancel_tx, } @@ -60,16 +71,18 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; + let pkid = publish.pkid; let publish = Request::Publish(publish); - self.request_tx.send(publish).await?; - Ok(()) + self.request_buf.lock().unwrap().push_back(publish); + self.request_tx.send(()).await?; + Ok(pkid) } /// Sends a MQTT Publish to the eventloop @@ -79,16 +92,18 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; + let pkid = publish.pkid; let publish = Request::Publish(publish); - self.request_tx.try_send(publish)?; - Ok(()) + self.request_buf.lock().unwrap().push_back(publish); + self.request_tx.try_send(())?; + Ok(pkid) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. @@ -96,7 +111,8 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.send(ack).await?; + self.request_buf.lock().unwrap().push_back(ack); + self.request_tx.send(()).await?; } Ok(()) } @@ -105,7 +121,8 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_buf.lock().unwrap().push_back(ack); + self.request_tx.try_send(())?; } Ok(()) } @@ -124,7 +141,8 @@ impl AsyncClient { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; let publish = Request::Publish(publish); - self.request_tx.send(publish).await?; + self.request_buf.lock().unwrap().push_back(publish); + self.request_tx.send(()).await?; Ok(()) } @@ -132,7 +150,8 @@ impl AsyncClient { pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); let request = Request::Subscribe(subscribe); - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } @@ -140,7 +159,8 @@ impl AsyncClient { pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); let request = Request::Subscribe(subscribe); - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } @@ -151,7 +171,8 @@ impl AsyncClient { { let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } @@ -162,7 +183,8 @@ impl AsyncClient { { let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } @@ -170,7 +192,8 @@ impl AsyncClient { pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } @@ -178,28 +201,42 @@ impl AsyncClient { pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } /// Stops the eventloop right away pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx.send(()).await?; - Ok(()) + self.cancel_tx.send(()).await.map_err(ClientError::Cancel) + } + + #[inline] + async fn send_and_notify(&self, request: Request) -> Result<(), async_channel::SendError<()>> { + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await + } + + #[inline] + fn try_send_and_notify(&self, request: Request) -> Result<(), async_channel::TrySendError<()>> { + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(()) } } @@ -242,13 +279,12 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - pollster::block_on(self.client.publish(topic, qos, retain, payload))?; - Ok(()) + pollster::block_on(self.client.publish(topic, qos, retain, payload)) } pub fn try_publish( @@ -257,31 +293,27 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - pollster::block_on(self.client.ack(publish))?; - Ok(()) + pollster::block_on(self.client.ack(publish)) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { - pollster::block_on(self.client.subscribe(topic, qos))?; - Ok(()) + pollster::block_on(self.client.subscribe(topic, qos)) } /// Sends a MQTT Subscribe to the eventloop @@ -290,8 +322,7 @@ impl Client { topic: S, qos: QoS, ) -> Result<(), ClientError> { - self.client.try_subscribe(topic, qos)?; - Ok(()) + self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -311,32 +342,27 @@ impl Client { /// Sends a MQTT Unsubscribe to the eventloop pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - pollster::block_on(self.client.unsubscribe(topic))?; - Ok(()) + pollster::block_on(self.client.unsubscribe(topic)) } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - self.client.try_unsubscribe(topic)?; - Ok(()) + self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.disconnect())?; - Ok(()) + pollster::block_on(self.client.disconnect()) } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&mut self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + self.client.try_disconnect() } /// Stops the eventloop right away pub fn cancel(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.cancel())?; - Ok(()) + pollster::block_on(self.client.cancel()) } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index fe8b1cb6b..1c536db26 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -14,12 +14,16 @@ use tokio::time::{self, error::Elapsed, Instant, Sleep}; #[cfg(feature = "websocket")] use ws_stream_tungstenite::WsStream; -use std::io; #[cfg(unix)] use std::path::Path; -use std::pin::Pin; -use std::time::Duration; -use std::vec::IntoIter; +use std::{ + collections::VecDeque, + io, + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, + vec::IntoIter, +}; /// Critical errors during eventloop polling #[derive(Debug, thiserror::Error)] @@ -48,10 +52,12 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, + request_buf: Arc>>, + request_buf_cache: VecDeque, /// Request stream - pub requests_rx: Receiver, + pub requests_rx: Receiver<()>, /// Requests handle to send requests - pub requests_tx: Sender, + pub requests_tx: Sender<()>, /// Pending packets from last session pub pending: IntoIter, /// Network connection to the broker @@ -78,7 +84,8 @@ impl EventLoop { /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { let (cancel_tx, cancel_rx) = bounded(5); - let (requests_tx, requests_rx) = bounded(cap); + let (requests_tx, requests_rx) = bounded(1); + let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); let pending = pending.into_iter(); let max_inflight = options.inflight; @@ -87,6 +94,8 @@ impl EventLoop { EventLoop { options, state: MqttState::new(max_inflight, manual_acks), + request_buf, + request_buf_cache: VecDeque::with_capacity(cap), requests_tx, requests_rx, pending, @@ -98,10 +107,14 @@ impl EventLoop { } /// Returns a handle to communicate with this eventloop - pub fn handle(&self) -> Sender { + pub fn handle(&self) -> Sender<()> { self.requests_tx.clone() } + pub fn buf(&self) -> &Arc>> { + &self.request_buf + } + /// Handle for cancelling the eventloop. /// /// Can be useful in cases when connection should be halted immediately @@ -156,68 +169,80 @@ impl EventLoop { return Ok(event); } - // this loop is necessary since self.incoming.pop_front() might return None. In that case, - // instead of returning a None event, we try again. - select! { - // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // Pull next request from user requests channel. - // If conditions in the below branch are for flow control. We read next user - // user request only when inflight messages are < configured inflight and there - // are no collisions while handling previous outgoing requests. - // - // Flow control is based on ack count. If inflight packet count in the buffer is - // less than max_inflight setting, next outgoing request will progress. For this - // to work correctly, broker should ack in sequence (a lot of brokers won't) - // - // E.g If max inflight = 5, user requests will be blocked when inflight queue - // looks like this -> [1, 2, 3, 4, 5]. - // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. - // This pulls next user request. But because max packet id = max_inflight, next - // user request's packet id will roll to 1. This replaces existing packet id 1. - // Resulting in a collision - // - // Eventloop can stop receiving outgoing user requests when previous outgoing - // request collided. I.e collision state. Collision state will be cleared only - // when correct ack is received - // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. - // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. - // After collision with pkid 1 -> [1b ,2, x, 4, 5]. - // 1a is saved to state and event loop is set to collision mode stopping new - // outgoing requests (along with 1b). - o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { - Ok(request) => { + // this loop is necessary as self.request_buf might be empty, in which case it is possible + // for self.state.events to be empty, and so popping off from it might return None. If None + // is returned, we select again. + loop { + select! { + // Pull a bunch of packets from network, reply in bunch and yield the first item + o = network.readb(&mut self.state) => { + o?; + // flush all the acks and return first incoming packet + network.flush(&mut self.state.write).await?; + return Ok(self.state.events.pop_front().unwrap()); + }, + // Pull next request from user requests channel. + // If conditions in the below branch are for flow control. We read next user + // request only when inflight messages are < configured inflight and there are no + // collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this to + // work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue looks + // like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { + Ok(_request_notif) => { + // swapping to avoid blocking the mutex + std::mem::swap(&mut self.request_buf_cache,&mut *self.request_buf.lock().unwrap()); + if self.request_buf_cache.is_empty() { + continue; + } + for request in self.request_buf_cache.drain(..) { + self.state.handle_outgoing_packet(request)?; + } + network.flush(&mut self.state.write).await?; + // remaining events in the self.state.events will be taken out in next call + // to poll() even before the select! is used. + return Ok(self.state.events.pop_front().unwrap()) + } + Err(_) => return Err(ConnectionError::RequestsDone), + }, + // Handle the next pending packet from previous session. Disable + // this branch when done with all the pending packets + Some(request) = next_pending(throttle, &mut self.pending), if pending => { self.state.handle_outgoing_packet(request)?; network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) + return Ok(self.state.events.pop_front().unwrap()) + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + + self.state.handle_outgoing_packet(Request::PingReq)?; + network.flush(&mut self.state.write).await?; + return Ok(self.state.events.pop_front().unwrap()) + } + // cancellation requests to stop the polling + _ = self.cancel_rx.recv() => { + return Err(ConnectionError::Cancel) } - Err(_) => Err(ConnectionError::RequestsDone), - }, - // Handle the next pending packet from previous session. Disable - // this branch when done with all the pending packets - Some(request) = next_pending(throttle, &mut self.pending), if pending => { - self.state.handle_outgoing_packet(request)?; - network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // We generate pings irrespective of network activity. This keeps the ping logic - // simple. We can change this behavior in future if necessary (to prevent extra pings) - _ = self.keepalive_timeout.as_mut().unwrap() => { - let timeout = self.keepalive_timeout.as_mut().unwrap(); - timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - - self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) - } - // cancellation requests to stop the polling - _ = self.cancel_rx.recv() => { - Err(ConnectionError::Cancel) } } } From 7313ea7f9db43f8346a6215b09c9334a83f7e910 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sat, 12 Feb 2022 15:25:19 +0530 Subject: [PATCH 08/41] rumqttc: client: more intuitive API + code reduction Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 100 ++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 58 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 20ad59f95..0d5acc461 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -16,10 +16,10 @@ use tokio::runtime::{self, Runtime}; pub enum ClientError { #[error("Failed to send cancel request to eventloop")] Cancel(SendError<()>), - #[error("Failed to send mqtt requests to eventloop")] - Request(#[from] SendError<()>), - #[error("Failed to send mqtt requests to eventloop")] - TryRequest(#[from] TrySendError<()>), + #[error("Failed to send mqtt request to eventloop, the evenloop has been closed")] + EventloopClosed, + #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] + RequestsFull, #[error("Serialization error")] Mqtt4(mqttbytes::Error), } @@ -29,6 +29,7 @@ pub enum ClientError { #[derive(Clone, Debug)] pub struct AsyncClient { request_buf: Arc>>, + request_buf_capacity: usize, request_tx: Sender<()>, cancel_tx: Sender<()>, } @@ -43,6 +44,7 @@ impl AsyncClient { let client = AsyncClient { request_buf, + request_buf_capacity: cap, request_tx, cancel_tx, }; @@ -56,9 +58,11 @@ impl AsyncClient { request_buf: Arc>>, request_tx: Sender<()>, cancel_tx: Sender<()>, + cap: usize, ) -> AsyncClient { AsyncClient { request_buf, + request_buf_capacity: cap, request_tx, cancel_tx, } @@ -79,9 +83,7 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - let publish = Request::Publish(publish); - self.request_buf.lock().unwrap().push_back(publish); - self.request_tx.send(()).await?; + self.send_and_notify(Request::Publish(publish)).await?; Ok(pkid) } @@ -100,9 +102,7 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - let publish = Request::Publish(publish); - self.request_buf.lock().unwrap().push_back(publish); - self.request_tx.try_send(())?; + self.try_send_and_notify(Request::Publish(publish))?; Ok(pkid) } @@ -111,8 +111,7 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_buf.lock().unwrap().push_back(ack); - self.request_tx.send(()).await?; + self.send_and_notify(ack).await?; } Ok(()) } @@ -121,8 +120,7 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_buf.lock().unwrap().push_back(ack); - self.request_tx.try_send(())?; + self.try_send_and_notify(ack)?; } Ok(()) } @@ -140,28 +138,19 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_buf.lock().unwrap().push_back(publish); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Publish(publish)).await } /// Sends a MQTT Subscribe to the eventloop pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Subscribe(subscribe)).await } /// Sends a MQTT Subscribe to the eventloop pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -170,10 +159,7 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Subscribe(subscribe)).await } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -182,44 +168,30 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Unsubscribe(unsubscribe)) + .await } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Disconnect) } /// Stops the eventloop right away @@ -227,16 +199,28 @@ impl AsyncClient { self.cancel_tx.send(()).await.map_err(ClientError::Cancel) } - #[inline] - async fn send_and_notify(&self, request: Request) -> Result<(), async_channel::SendError<()>> { - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await + async fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(SendError(_)) = self.request_tx.send(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) } - #[inline] - fn try_send_and_notify(&self, request: Request) -> Result<(), async_channel::TrySendError<()>> { - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(()) + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Closed(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) } } From fd166d84ed48bfba339e97ff4aee5677bcb5a354 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sat, 12 Feb 2022 15:25:45 +0530 Subject: [PATCH 09/41] rumqttc: examples: remove needless import Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 2 +- rumqttc/examples/asyncpubsub.rs | 2 +- rumqttc/examples/syncpubsub.rs | 2 +- rumqttc/examples/tls.rs | 2 +- rumqttc/examples/tls2.rs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 127e235db..9f2678b51 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/asyncpubsub.rs b/rumqttc/examples/asyncpubsub.rs index fef15451d..b4de624e7 100644 --- a/rumqttc/examples/asyncpubsub.rs +++ b/rumqttc/examples/asyncpubsub.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{self, AsyncClient, MqttOptions, QoS}; +use rumqttc::v4::{AsyncClient, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index a14bcda4e..2001b4d29 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -1,4 +1,4 @@ -use rumqttc::v4::{self, Client, LastWill, MqttOptions, QoS}; +use rumqttc::v4::{Client, LastWill, MqttOptions, QoS}; use std::thread; use std::time::Duration; diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index 90bb4180f..dedf975db 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::v4::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; +use rumqttc::v4::{AsyncClient, Event, Incoming, MqttOptions, Transport}; use rustls::ClientConfig; use std::error::Error; diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index cd81364a4..f6392b99d 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::v4::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; +use rumqttc::v4::{AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; use std::error::Error; #[tokio::main] From 24728ef4ed34acb58e15f5e9d4c5159fefdf7cc9 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sun, 13 Feb 2022 18:03:07 +0530 Subject: [PATCH 10/41] rumqttc: v5: replace async_channel + pollseter -> flume Signed-off-by: Abhik Jain --- rumqttc/Cargo.toml | 1 + rumqttc/src/v5/client.rs | 74 +++++++++++++++++++++++++------------ rumqttc/src/v5/eventloop.rs | 8 ++-- rumqttc/src/v5/mod.rs | 2 +- 4 files changed, 56 insertions(+), 29 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 81a7b05b7..bbf7dbd7b 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -31,6 +31,7 @@ log = "0.4" thiserror = "1.0.21" http = "^0.2" url = { version = "2.2", default-features = false, optional = true } +flume = "0.10.10" [dev-dependencies] pretty_env_logger = "0.4" diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 0d5acc461..6f2930bd2 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -2,8 +2,8 @@ //! async eventloop. use crate::v5::{packet::*, ConnectionError, Event, EventLoop, MqttOptions, Request}; -use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; use std::{ collections::VecDeque, mem, @@ -83,7 +83,8 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - self.send_and_notify(Request::Publish(publish)).await?; + self.send_async_and_notify(Request::Publish(publish)) + .await?; Ok(pkid) } @@ -108,18 +109,15 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - - if let Some(ack) = ack { - self.send_and_notify(ack).await?; + if let Some(ack) = get_ack_req(publish) { + self.send_async_and_notify(ack).await?; } Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - if let Some(ack) = ack { + if let Some(ack) = get_ack_req(publish) { self.try_send_and_notify(ack)?; } Ok(()) @@ -138,13 +136,14 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - self.send_and_notify(Request::Publish(publish)).await + self.send_async_and_notify(Request::Publish(publish)).await } /// Sends a MQTT Subscribe to the eventloop pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.send_and_notify(Request::Subscribe(subscribe)).await + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await } /// Sends a MQTT Subscribe to the eventloop @@ -159,7 +158,8 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.send_and_notify(Request::Subscribe(subscribe)).await + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -174,7 +174,7 @@ impl AsyncClient { /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - self.send_and_notify(Request::Unsubscribe(unsubscribe)) + self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) .await } @@ -186,7 +186,7 @@ impl AsyncClient { /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { - self.send_and_notify(Request::Disconnect).await + self.send_async_and_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop @@ -196,16 +196,31 @@ impl AsyncClient { /// Stops the eventloop right away pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx.send(()).await.map_err(ClientError::Cancel) + self.cancel_tx + .send_async(()) + .await + .map_err(ClientError::Cancel) } - async fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.request_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); - if let Err(SendError(_)) = self.request_tx.send(()).await { + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(SendError(_)) = self.request_tx.send(()) { return Err(ClientError::EventloopClosed); }; Ok(()) @@ -217,7 +232,7 @@ impl AsyncClient { return Err(ClientError::RequestsFull); } request_buf.push_back(request); - if let Err(TrySendError::Closed(_)) = self.request_tx.try_send(()) { + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { return Err(ClientError::EventloopClosed); } Ok(()) @@ -268,7 +283,11 @@ impl Client { S: Into, V: Into>, { - pollster::block_on(self.client.publish(topic, qos, retain, payload)) + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.client.send_and_notify(Request::Publish(publish))?; + Ok(pkid) } pub fn try_publish( @@ -287,7 +306,10 @@ impl Client { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - pollster::block_on(self.client.ack(publish)) + if let Some(ack) = get_ack_req(publish) { + self.client.send_and_notify(ack)?; + } + Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. @@ -297,7 +319,8 @@ impl Client { /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { - pollster::block_on(self.client.subscribe(topic, qos)) + let subscribe = Subscribe::new(topic.into(), qos); + self.client.send_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe to the eventloop @@ -314,7 +337,8 @@ impl Client { where T: IntoIterator, { - pollster::block_on(self.client.subscribe_many(topics)) + let subscribe = Subscribe::new_many(topics); + self.client.send_and_notify(Request::Subscribe(subscribe)) } pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> @@ -326,7 +350,9 @@ impl Client { /// Sends a MQTT Unsubscribe to the eventloop pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - pollster::block_on(self.client.unsubscribe(topic)) + let unsubscribe = Unsubscribe::new(topic.into()); + self.client + .send_and_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT Unsubscribe to the eventloop @@ -336,7 +362,7 @@ impl Client { /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.disconnect()) + self.client.send_and_notify(Request::Disconnect) } /// Sends a MQTT disconnect to the eventloop @@ -346,7 +372,7 @@ impl Client { /// Stops the eventloop right away pub fn cancel(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.cancel()) + self.client.cancel_tx.send(()).map_err(ClientError::Cancel) } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 1c536db26..3eb1dd6c6 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,9 +3,9 @@ use crate::v5::{ StateError, Transport, }; -use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use flume::{bounded, Receiver, Sender}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; @@ -205,7 +205,7 @@ impl EventLoop { // After collision with pkid 1 -> [1b ,2, x, 4, 5]. // 1a is saved to state and event loop is set to collision mode stopping new // outgoing requests (along with 1b). - o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { + o = self.requests_rx.recv_async(), if !inflight_full && !pending && !collision => match o { Ok(_request_notif) => { // swapping to avoid blocking the mutex std::mem::swap(&mut self.request_buf_cache,&mut *self.request_buf.lock().unwrap()); @@ -240,7 +240,7 @@ impl EventLoop { return Ok(self.state.events.pop_front().unwrap()) } // cancellation requests to stop the polling - _ = self.cancel_rx.recv() => { + _ = self.cancel_rx.recv_async() => { return Err(ConnectionError::Cancel) } } @@ -256,7 +256,7 @@ async fn connect_or_cancel( // resolved. Returns with an error if connections fail continuously select! { o = connect(options) => o, - _ = cancel_rx.recv() => { + _ = cancel_rx.recv_async() => { Err(ConnectionError::Cancel) } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d4b276151..83e21ef01 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -9,9 +9,9 @@ mod packet; mod state; mod tls; -pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use flume::{SendError, Sender, TrySendError}; pub use packet::*; pub use state::{MqttState, StateError}; pub use tls::Error; From 27dd057c162418050205c6e8429cfaa8719b9a58 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 15 Feb 2022 11:40:58 +0530 Subject: [PATCH 11/41] rumqttc: v5: AsyncClient: limit scope of MutecGuard for async runtime Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 6f2930bd2..06eec457e 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -203,11 +203,13 @@ impl AsyncClient { } async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { - return Err(ClientError::RequestsFull); + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); } - request_buf.push_back(request); if let Err(SendError(_)) = self.request_tx.send_async(()).await { return Err(ClientError::EventloopClosed); }; From 5428339a0fbff830aa854ac4bb051007ba4b57e1 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 15 Feb 2022 11:43:54 +0530 Subject: [PATCH 12/41] rumqttc: v5: add examples Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks_v5.rs | 76 ++++++++++++++++++++++++ rumqttc/examples/asyncpubsub_v5.rs | 43 ++++++++++++++ rumqttc/examples/syncpubsub_v5.rs | 35 +++++++++++ 3 files changed, 154 insertions(+) create mode 100644 rumqttc/examples/async_manual_acks_v5.rs create mode 100644 rumqttc/examples/asyncpubsub_v5.rs create mode 100644 rumqttc/examples/syncpubsub_v5.rs diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs new file mode 100644 index 000000000..eb01d2e3d --- /dev/null +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -0,0 +1,76 @@ +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +fn create_conn() -> (AsyncClient, EventLoop) { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_manual_acks(true) + .set_clean_session(false); + + AsyncClient::new(mqttoptions, 10) +} + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + + // create mqtt connection with clean_session = false and manual_acks = true + let (client, mut eventloop) = create_conn(); + + // subscribe example topic + client + .subscribe("hello/world", QoS::AtLeastOnce) + .await + .unwrap(); + + task::spawn(async move { + // send some messages to example topic and disconnect + requests(client.clone()).await; + client.disconnect().await.unwrap() + }); + + loop { + // get subscribed messages without acking + let event = eventloop.poll().await; + println!("{:?}", event); + if let Err(_err) = event { + // break loop on disconnection + break; + } + } + + // create new broker connection + let (client, mut eventloop) = create_conn(); + + loop { + // previously published messages should be republished after reconnection. + let event = eventloop.poll().await; + println!("{:?}", event); + match event { + Ok(Event::Incoming(Incoming::Publish(publish))) => { + // this time we will ack incoming publishes. + // Its important not to block eventloop as this can cause deadlock. + let c = client.clone(); + tokio::spawn(async move { + c.ack(&publish).await.unwrap(); + }); + } + _ => {} + } + } +} + +async fn requests(client: AsyncClient) { + for i in 1..=10 { + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } +} diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs new file mode 100644 index 000000000..b0b7a6698 --- /dev/null +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -0,0 +1,43 @@ +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + time::sleep(Duration::from_secs(3)).await; + }); + + loop { + let event = eventloop.poll().await; + println!("{:?}", event.unwrap()); + } +} + +async fn requests(client: AsyncClient) { + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap(); + + for i in 1..=10 { + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } + + time::sleep(Duration::from_secs(120)).await; +} diff --git a/rumqttc/examples/syncpubsub_v5.rs b/rumqttc/examples/syncpubsub_v5.rs new file mode 100644 index 000000000..857ab4789 --- /dev/null +++ b/rumqttc/examples/syncpubsub_v5.rs @@ -0,0 +1,35 @@ +use rumqttc::v5::{Client, LastWill, MqttOptions, QoS}; +use std::thread; +use std::time::Duration; + +fn main() { + pretty_env_logger::init(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_last_will(will); + + let (client, mut connection) = Client::new(mqttoptions, 10); + thread::spawn(move || publish(client)); + + for (i, notification) in connection.iter().enumerate() { + println!("{}. Notification = {:?}", i, notification); + } + + println!("Done with the stream!!"); +} + +fn publish(mut client: Client) { + client.subscribe("hello/+/world", QoS::AtMostOnce).unwrap(); + for i in 0..10 { + let payload = vec![1; i as usize]; + let topic = format!("hello/{}/world", i); + let qos = QoS::AtLeastOnce; + + client.publish(topic, qos, true, payload).unwrap(); + } + + thread::sleep(Duration::from_secs(1)); +} From bae184a5769697ebc3c44ed1ddf59b768c17ad39 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 15 Feb 2022 18:34:24 +0530 Subject: [PATCH 13/41] rumqttc: v4: examples: fix Signed-off-by: Abhik Jain --- rumqttc/examples/websocket.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rumqttc/examples/websocket.rs b/rumqttc/examples/websocket.rs index 4336d2f7a..1467723c9 100644 --- a/rumqttc/examples/websocket.rs +++ b/rumqttc/examples/websocket.rs @@ -1,6 +1,8 @@ use tokio::{task, time}; -use rumqttc::{self, AsyncClient, MqttOptions, QoS, Transport}; +use rumqttc::v4::{AsyncClient, MqttOptions, QoS}; +#[cfg(feature = "websocket")] +use rumqttc::v4::Transport; use std::error::Error; use std::time::Duration; From b5518b5620a97dcf984f67b63f532ddbec2fc87d Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Mon, 7 Mar 2022 12:27:16 +0530 Subject: [PATCH 14/41] cmacos test Signed-off-by: Abhik Jain --- rumqttc/tests/reliability.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 3a129b6b8..db41af29d 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -360,7 +360,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { if ack == 1 { let elapsed = start.elapsed().as_millis() as i64; let deviation_millis: i64 = (5000 - elapsed).abs(); - assert!(deviation_millis < 10); + assert!(deviation_millis < 100); break; } } From 8045e84ff6a3892738f7acacab97c535f79dbb94 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 16 Mar 2022 11:27:39 +0530 Subject: [PATCH 15/41] Limit GitHub actions to running only on PRs affecting code/dependencies (#375) * Limit to running CI in case there are changes to `*.rs` or `Cargo.*` files --- .github/workflows/features.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/features.yml b/.github/workflows/features.yml index bb03036f8..8d2dafc06 100644 --- a/.github/workflows/features.yml +++ b/.github/workflows/features.yml @@ -2,6 +2,10 @@ on: pull_request: branches: - master + paths: + - '**.rs' + - 'Cargo.*' + - '*/Cargo.*' name: features From 70244d9efc9990f652bca9870cd73738016313b9 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 18 Mar 2022 00:54:09 +0530 Subject: [PATCH 16/41] rumqttc: v5: separate pub/sub client option Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 274 +++++++++++++++++++++++++++ rumqttc/src/v5/client/mod.rs | 102 ++++++++++ rumqttc/src/v5/client/publisher.rs | 98 ++++++++++ rumqttc/src/v5/client/subscriber.rs | 121 ++++++++++++ rumqttc/src/v5/client/syncclient.rs | 135 +++++++++++++ rumqttc/src/v5/eventloop.rs | 8 +- rumqttc/src/v5/mod.rs | 10 + rumqttc/src/v5/state.rs | 21 +- 8 files changed, 761 insertions(+), 8 deletions(-) create mode 100644 rumqttc/src/v5/client/asyncclient.rs create mode 100644 rumqttc/src/v5/client/mod.rs create mode 100644 rumqttc/src/v5/client/publisher.rs create mode 100644 rumqttc/src/v5/client/subscriber.rs create mode 100644 rumqttc/src/v5/client/syncclient.rs diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs new file mode 100644 index 000000000..a00b295ef --- /dev/null +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -0,0 +1,274 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{ + client::{get_ack_req, Publisher, Subscriber}, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + ClientError, EventLoop, MqttOptions, QoS, Request, +}; + +/// `AsyncClient` to communicate with MQTT `Eventloop` +/// This is cloneable and can be used to asynchronously Publish, Subscribe. +#[derive(Clone, Debug)] +pub struct AsyncClient { + request_buf: Arc>>, + sub_events_buf: Arc>>, + sub_events_buf_cache: VecDeque, + request_buf_capacity: usize, + request_tx: Sender<()>, + pub(crate) cancel_tx: Sender<()>, +} + +impl AsyncClient { + /// Create a new `AsyncClient` + pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { + let mut eventloop = EventLoop::new(options, cap); + let request_buf = eventloop.request_buf().clone(); + let sub_events_buf = eventloop.sub_events_buf().clone(); + let sub_events_buf_cache = VecDeque::with_capacity(cap); + let request_tx = eventloop.handle(); + let cancel_tx = eventloop.cancel_handle(); + + let client = AsyncClient { + request_buf, + request_buf_capacity: cap, + sub_events_buf, + sub_events_buf_cache, + request_tx, + cancel_tx, + }; + + (client, eventloop) + } + + /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for + /// creating a test instance. + pub fn from_senders( + request_buf: Arc>>, + sub_events_buf: Arc>>, + request_tx: Sender<()>, + cancel_tx: Sender<()>, + cap: usize, + ) -> AsyncClient { + AsyncClient { + request_buf, + request_buf_capacity: cap, + sub_events_buf, + sub_events_buf_cache: VecDeque::with_capacity(cap), + request_tx, + cancel_tx, + } + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.send_async_and_notify(Request::Publish(publish)) + .await?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.try_send_and_notify(Request::Publish(publish))?; + Ok(pkid) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + self.send_async_and_notify(ack).await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + self.try_send_and_notify(ack)?; + } + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: Bytes, + ) -> Result<(), ClientError> + where + S: Into, + { + let mut publish = Publish::from_bytes(topic, qos, payload); + publish.retain = retain; + self.send_async_and_notify(Request::Publish(publish)).await + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) + .await + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) + } + + /// Sends a MQTT disconnect to the eventloop + pub async fn disconnect(&self) -> Result<(), ClientError> { + self.send_async_and_notify(Request::Disconnect).await + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + self.try_send_and_notify(Request::Disconnect) + } + + /// Stops the eventloop right away + pub async fn cancel(&self) -> Result<(), ClientError> { + self.cancel_tx + .send_async(()) + .await + .map_err(ClientError::Cancel) + } + + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + } + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(SendError(_)) = self.request_tx.send(()) { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } + + pub fn next_publish(&mut self) -> Option { + if let Some(publish) = self.sub_events_buf_cache.pop_front() { + return Some(publish); + } + + std::mem::swap( + &mut self.sub_events_buf_cache, + &mut *self.sub_events_buf.lock().unwrap(), + ); + self.sub_events_buf_cache.pop_front() + } + + pub async fn split( + self, + publish_topic: impl Into, + publish_qos: QoS, + ) -> Result<(Publisher, Subscriber), ClientError> { + let publisher = Publisher { + request_buf: self.request_buf.clone(), + request_buf_capacity: self.request_buf_capacity, + request_tx: self.request_tx.clone(), + cancel_tx: self.cancel_tx.clone(), + publish_topic: publish_topic.into(), + publish_qos, + }; + let subscriber = Subscriber { + request_buf: self.request_buf, + sub_events_buf: self.sub_events_buf, + sub_events_buf_cache: self.sub_events_buf_cache, + request_buf_capacity: self.request_buf_capacity, + request_tx: self.request_tx, + }; + Ok((publisher, subscriber)) + } +} diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs new file mode 100644 index 000000000..3156126df --- /dev/null +++ b/rumqttc/src/v5/client/mod.rs @@ -0,0 +1,102 @@ +//! This module offers a high level synchronous and asynchronous abstraction to +//! async eventloop. +use crate::v5::{packet::*, ConnectionError, Event, EventLoop, Request}; + +use flume::SendError; +use std::mem; +use tokio::runtime::{self, Runtime}; + +mod asyncclient; +pub use asyncclient::AsyncClient; +mod publisher; +pub use publisher::Publisher; +mod subscriber; +pub use subscriber::Subscriber; +mod syncclient; +pub use syncclient::Client; + +/// Client Error +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("Failed to send cancel request to eventloop")] + Cancel(SendError<()>), + #[error("Failed to send mqtt request to eventloop, the evenloop has been closed")] + EventloopClosed, + #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] + RequestsFull, + #[error("Serialization error")] + Mqtt5(Error), +} + +fn get_ack_req(qos: QoS, pkid: u16) -> Option { + let ack = match qos { + QoS::AtMostOnce => return None, + QoS::AtLeastOnce => Request::PubAck(PubAck::new(pkid)), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(pkid)), + }; + Some(ack) +} + + +/// MQTT connection. Maintains all the necessary state +pub struct Connection { + pub eventloop: EventLoop, + runtime: Option, +} + +impl Connection { + fn new(eventloop: EventLoop, runtime: Runtime) -> Connection { + Connection { + eventloop, + runtime: Some(runtime), + } + } + + /// Returns an iterator over this connection. Iterating over this is all that's + /// necessary to make connection progress and maintain a robust connection. + /// Just continuing to loop will reconnect + /// **NOTE** Don't block this while iterating + #[must_use = "Connection should be iterated over a loop to make progress"] + pub fn iter(&mut self) -> Iter { + let runtime = self.runtime.take().unwrap(); + Iter { + connection: self, + runtime, + } + } +} + +/// Iterator which polls the eventloop for connection progress +pub struct Iter<'a> { + connection: &'a mut Connection, + runtime: runtime::Runtime, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + let f = self.connection.eventloop.poll(); + match self.runtime.block_on(f) { + Ok(v) => Some(Ok(v)), + // closing of request channel should stop the iterator + Err(ConnectionError::RequestsDone) => { + trace!("Done with requests"); + None + } + Err(ConnectionError::Cancel) => { + trace!("Cancellation request received"); + None + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a> Drop for Iter<'a> { + fn drop(&mut self) { + // TODO: Don't create new runtime in drop + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + self.connection.runtime = Some(mem::replace(&mut self.runtime, runtime)); + } +} diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs new file mode 100644 index 000000000..ae47769f6 --- /dev/null +++ b/rumqttc/src/v5/client/publisher.rs @@ -0,0 +1,98 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{packet::Publish, ClientError, QoS, Request}; + +pub struct Publisher { + pub(crate) request_buf: Arc>>, + pub(crate) request_buf_capacity: usize, + pub(crate) request_tx: Sender<()>, + pub(crate) cancel_tx: Sender<()>, + pub(crate) publish_topic: String, + pub(crate) publish_qos: QoS, +} + +impl Publisher { + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + retain: bool, + payload: impl Into>, + ) -> Result { + let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.send_async_and_notify(Request::Publish(publish)) + .await?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + retain: bool, + payload: impl Into>, + ) -> Result { + let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.try_send_and_notify(Request::Publish(publish))?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result<(), ClientError> { + let mut publish = Publish::from_bytes(&self.publish_topic, self.publish_qos, payload); + publish.retain = retain; + self.send_async_and_notify(Request::Publish(publish)).await + } + + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + } + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub async fn disconnect(&self) -> Result<(), ClientError> { + self.send_async_and_notify(Request::Disconnect).await + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + self.try_send_and_notify(Request::Disconnect) + } + + /// Stops the eventloop right away + pub async fn cancel(&self) -> Result<(), ClientError> { + self.cancel_tx + .send_async(()) + .await + .map_err(ClientError::Cancel) + } + + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } +} diff --git a/rumqttc/src/v5/client/subscriber.rs b/rumqttc/src/v5/client/subscriber.rs new file mode 100644 index 000000000..6a7889f18 --- /dev/null +++ b/rumqttc/src/v5/client/subscriber.rs @@ -0,0 +1,121 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{ + client::get_ack_req, ClientError, Publish, QoS, Request, Subscribe, SubscribeFilter, + Unsubscribe, +}; + +#[derive(Debug, Clone)] +pub struct Subscriber { + pub(crate) request_buf: Arc>>, + pub(crate) sub_events_buf: Arc>>, + pub(crate) sub_events_buf_cache: VecDeque, + pub(crate) request_buf_capacity: usize, + pub(crate) request_tx: Sender<()>, +} + +impl Subscriber { + pub fn next_publish(&mut self) -> Option { + if let Some(publish) = self.sub_events_buf_cache.pop_front() { + return Some(publish); + } + + std::mem::swap( + &mut self.sub_events_buf_cache, + &mut *self.sub_events_buf.lock().unwrap(), + ); + self.sub_events_buf_cache.pop_front() + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(qos, pkid) { + self.send_async_and_notify(ack).await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(qos, pkid) { + self.try_send_and_notify(ack)?; + } + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) + .await + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) + } + + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + } + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } +} diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs new file mode 100644 index 000000000..bed105012 --- /dev/null +++ b/rumqttc/src/v5/client/syncclient.rs @@ -0,0 +1,135 @@ +use tokio::runtime; + +use crate::v5::{ + client::get_ack_req, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + AsyncClient, ClientError, Connection, MqttOptions, QoS, Request, +}; + +/// `Client` to communicate with MQTT eventloop `Connection`. +/// +/// Client is cloneable and can be used to synchronously Publish, Subscribe. +/// Asynchronous channel handle can also be extracted if necessary +#[derive(Clone)] +pub struct Client { + client: AsyncClient, +} + +impl Client { + /// Create a new `Client` + pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { + let (client, eventloop) = AsyncClient::new(options, cap); + let client = Client { client }; + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let connection = Connection::new(eventloop, runtime); + (client, connection) + } + + /// Sends a MQTT Publish to the eventloop + pub fn publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.client.send_and_notify(Request::Publish(publish))?; + Ok(pkid) + } + + pub fn try_publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + self.client.try_publish(topic, qos, retain, payload) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + self.client.send_and_notify(ack)?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + self.client.try_ack(publish) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.client.send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result<(), ClientError> { + self.client.try_subscribe(topic, qos) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.client.send_and_notify(Request::Subscribe(subscribe)) + } + + pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + self.client.try_subscribe_many(topics) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.client + .send_and_notify(Request::Unsubscribe(unsubscribe)) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + self.client.try_unsubscribe(topic) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn disconnect(&mut self) -> Result<(), ClientError> { + self.client.send_and_notify(Request::Disconnect) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + self.client.try_disconnect() + } + + /// Stops the eventloop right away + pub fn cancel(&mut self) -> Result<(), ClientError> { + self.client.cancel_tx.send(()).map_err(ClientError::Cancel) + } +} diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 659a713e7..7c1843100 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -93,7 +93,7 @@ impl EventLoop { EventLoop { options, - state: MqttState::new(max_inflight, manual_acks), + state: MqttState::new(max_inflight, manual_acks, cap), request_buf, request_buf_cache: VecDeque::with_capacity(cap), requests_tx, @@ -111,10 +111,14 @@ impl EventLoop { self.requests_tx.clone() } - pub fn buf(&self) -> &Arc>> { + pub fn request_buf(&self) -> &Arc>> { &self.request_buf } + pub fn sub_events_buf(&self) -> &Arc>> { + &self.state.sub_events_buf + } + /// Handle for cancelling the eventloop. /// /// Can be useful in cases when connection should be halted immediately diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 83e21ef01..d8d2db4c1 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -233,6 +233,7 @@ pub struct MqttOptions { /// If set to `true` MQTT acknowledgements are not sent automatically. /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. manual_acks: bool, + override_slow_subs: bool, } impl MqttOptions { @@ -260,6 +261,7 @@ impl MqttOptions { last_will: None, conn_timeout: 5, manual_acks: false, + override_slow_subs: false, } } @@ -408,6 +410,14 @@ impl MqttOptions { pub fn manual_acks(&self) -> bool { self.manual_acks } + + pub fn set_override_slow_subs(&mut self, val: bool) { + self.override_slow_subs = val; + } + + pub fn override_slow_subs(&self) -> bool { + self.override_slow_subs + } } #[cfg(feature = "url")] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 2bc2b470a..0cec123b3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -2,7 +2,11 @@ use super::{packet::*, Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use std::collections::VecDeque; -use std::{io, mem, time::Instant}; +use std::{ + io, mem, + sync::{Arc, Mutex}, + time::Instant, +}; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -75,13 +79,14 @@ pub struct MqttState { pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, + pub(crate) sub_events_buf: Arc>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, cap: usize) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -99,6 +104,7 @@ impl MqttState { events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, + sub_events_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), } } @@ -195,13 +201,12 @@ impl MqttState { let qos = publish.qos; match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => {}, QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); self.outgoing_puback(puback)? } - Ok(()) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -210,9 +215,13 @@ impl MqttState { let pubrec = PubRec::new(pkid); self.outgoing_pubrec(pubrec)?; } - Ok(()) } } + + // TODO: maybe limit the capacity of `self.sub_events_buf` + self.sub_events_buf.lock().unwrap().push_back(publish.clone()); + + Ok(()) } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { @@ -507,7 +516,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(100, false) + MqttState::new(100, false, 100) } #[test] From c9c5b72ad083c636d0a1d5ffd840618c273d1fdf Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 22 Mar 2022 19:23:25 +0530 Subject: [PATCH 17/41] rumqttc: v5: fix pkid Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 50 +++++++++++-- rumqttc/src/v5/client/publisher.rs | 37 ++++++++-- rumqttc/src/v5/state.rs | 102 ++++++++++++++++++--------- 3 files changed, 145 insertions(+), 44 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index a00b295ef..d811cb2fc 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -1,6 +1,9 @@ use std::{ collections::VecDeque, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, Mutex, + }, }; use bytes::Bytes; @@ -18,6 +21,8 @@ use crate::v5::{ pub struct AsyncClient { request_buf: Arc>>, sub_events_buf: Arc>>, + pkid_counter: Arc, + max_inflight: u16, sub_events_buf_cache: VecDeque, request_buf_capacity: usize, request_tx: Sender<()>, @@ -33,11 +38,15 @@ impl AsyncClient { let sub_events_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); let cancel_tx = eventloop.cancel_handle(); + let max_inflight = eventloop.state.max_inflight; + let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { request_buf, request_buf_capacity: cap, sub_events_buf, + pkid_counter, + max_inflight, sub_events_buf_cache, request_tx, cancel_tx, @@ -51,6 +60,8 @@ impl AsyncClient { pub fn from_senders( request_buf: Arc>>, sub_events_buf: Arc>>, + pkid_counter: Arc, + max_inflight: u16, request_tx: Sender<()>, cancel_tx: Sender<()>, cap: usize, @@ -58,6 +69,8 @@ impl AsyncClient { AsyncClient { request_buf, request_buf_capacity: cap, + pkid_counter, + max_inflight, sub_events_buf, sub_events_buf_cache: VecDeque::with_capacity(cap), request_tx, @@ -79,7 +92,8 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.send_async_and_notify(Request::Publish(publish)) .await?; Ok(pkid) @@ -99,7 +113,8 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.try_send_and_notify(Request::Publish(publish))?; Ok(pkid) } @@ -127,13 +142,16 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - self.send_async_and_notify(Request::Publish(publish)).await + let pkid = self.increment_pkid(); + publish.pkid = pkid; + self.send_async_and_notify(Request::Publish(publish)).await?; + Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop @@ -257,6 +275,8 @@ impl AsyncClient { let publisher = Publisher { request_buf: self.request_buf.clone(), request_buf_capacity: self.request_buf_capacity, + pkid_counter: self.pkid_counter, + max_inflight: self.max_inflight, request_tx: self.request_tx.clone(), cancel_tx: self.cancel_tx.clone(), publish_topic: publish_topic.into(), @@ -271,4 +291,24 @@ impl AsyncClient { }; Ok((publisher, subscriber)) } + + fn increment_pkid(&self) -> u16 { + let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); + loop { + let new_pkid = if cur_pkid > self.max_inflight { + 1 + } else { + cur_pkid + 1 + }; + match self.pkid_counter.compare_exchange( + cur_pkid, + new_pkid, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_prev_pkid) => break new_pkid, + Err(actual_pkid) => cur_pkid = actual_pkid, + } + } + } } diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs index ae47769f6..e4de4af8b 100644 --- a/rumqttc/src/v5/client/publisher.rs +++ b/rumqttc/src/v5/client/publisher.rs @@ -1,6 +1,6 @@ use std::{ collections::VecDeque, - sync::{Arc, Mutex}, + sync::{Arc, atomic::{AtomicU16, Ordering}, Mutex}, }; use bytes::Bytes; @@ -11,6 +11,8 @@ use crate::v5::{packet::Publish, ClientError, QoS, Request}; pub struct Publisher { pub(crate) request_buf: Arc>>, pub(crate) request_buf_capacity: usize, + pub(crate) pkid_counter: Arc, + pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, pub(crate) cancel_tx: Sender<()>, pub(crate) publish_topic: String, @@ -26,7 +28,8 @@ impl Publisher { ) -> Result { let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.send_async_and_notify(Request::Publish(publish)) .await?; Ok(pkid) @@ -40,16 +43,20 @@ impl Publisher { ) -> Result { let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.try_send_and_notify(Request::Publish(publish))?; Ok(pkid) } /// Sends a MQTT Publish to the eventloop - pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result<(), ClientError> { + pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result { let mut publish = Publish::from_bytes(&self.publish_topic, self.publish_qos, payload); + let pkid = self.increment_pkid(); + publish.pkid = pkid; publish.retain = retain; - self.send_async_and_notify(Request::Publish(publish)).await + self.send_async_and_notify(Request::Publish(publish)).await?; + Ok(pkid) } async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { @@ -95,4 +102,24 @@ impl Publisher { } Ok(()) } + + fn increment_pkid(&self) -> u16 { + let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); + loop { + let new_pkid = if cur_pkid > self.max_inflight { + 1 + } else { + cur_pkid + 1 + }; + match self.pkid_counter.compare_exchange( + cur_pkid, + new_pkid, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_prev_pkid) => break new_pkid, + Err(actual_pkid) => cur_pkid = actual_pkid, + } + } + } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 0cec123b3..c990f92d0 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -2,6 +2,7 @@ use super::{packet::*, Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use std::collections::VecDeque; +use std::sync::atomic::{AtomicU16, Ordering}; use std::{ io, mem, sync::{Arc, Mutex}, @@ -59,8 +60,6 @@ pub struct MqttState { last_incoming: Instant, /// Last outgoing packet time last_outgoing: Instant, - /// Packet id of the last outgoing packet - pub(crate) last_pkid: u16, /// Number of outgoing inflight publishes pub(crate) inflight: u16, /// Maximum number of allowed inflight @@ -80,6 +79,7 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, pub(crate) sub_events_buf: Arc>>, + pkid_counter: Arc, } impl MqttState { @@ -92,7 +92,6 @@ impl MqttState { collision_ping_count: 0, last_incoming: Instant::now(), last_outgoing: Instant::now(), - last_pkid: 0, inflight: 0, max_inflight, // index 0 is wasted as 0 is not a valid packet id @@ -105,6 +104,38 @@ impl MqttState { write: BytesMut::with_capacity(10 * 1024), manual_acks, sub_events_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + pkid_counter: Arc::new(AtomicU16::new(0)), + } + } + + #[inline] + pub(crate) fn pkid_counter(&self) -> &Arc { + &self.pkid_counter + } + + #[cfg(test)] + #[inline] + fn cur_pkid(&self) -> u16 { + self.pkid_counter.load(Ordering::SeqCst) + } + + pub(crate) fn increment_pkid(&self) -> u16 { + let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); + loop { + let new_pkid = if cur_pkid > self.max_inflight { + 1 + } else { + cur_pkid + 1 + }; + match self.pkid_counter.compare_exchange( + cur_pkid, + new_pkid, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_prev_pkid) => break new_pkid, + Err(actual_pkid) => cur_pkid = actual_pkid, + } } } @@ -201,7 +232,7 @@ impl MqttState { let qos = publish.qos; match qos { - QoS::AtMostOnce => {}, + QoS::AtMostOnce => {} QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); @@ -219,7 +250,10 @@ impl MqttState { } // TODO: maybe limit the capacity of `self.sub_events_buf` - self.sub_events_buf.lock().unwrap().push_back(publish.clone()); + self.sub_events_buf + .lock() + .unwrap() + .push_back(publish.clone()); Ok(()) } @@ -312,7 +346,7 @@ impl MqttState { fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { - publish.pkid = self.next_pkid(); + publish.pkid = self.increment_pkid(); } let pkid = publish.pkid; @@ -409,7 +443,7 @@ impl MqttState { } fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { - let pkid = self.next_pkid(); + let pkid = self.increment_pkid(); subscription.pkid = pkid; debug!( @@ -424,7 +458,7 @@ impl MqttState { } fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { - let pkid = self.next_pkid(); + let pkid = self.increment_pkid(); unsub.pkid = pkid; debug!( @@ -461,7 +495,7 @@ impl MqttState { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets 0 => { - pubrel.pkid = self.next_pkid(); + pubrel.pkid = self.increment_pkid(); pubrel } _ => pubrel, @@ -471,24 +505,24 @@ impl MqttState { Ok(pubrel) } - /// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation - /// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. - /// - fn next_pkid(&mut self) -> u16 { - let next_pkid = self.last_pkid + 1; - - // When next packet id is at the edge of inflight queue, - // set await flag. This instructs eventloop to stop - // processing requests until all the inflight publishes - // are acked - if next_pkid == self.max_inflight { - self.last_pkid = 0; - return next_pkid; - } - - self.last_pkid = next_pkid; - next_pkid - } + ///// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation + ///// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. + ///// + //fn next_pkid(&mut self) -> u16 { + // let next_pkid = self.last_pkid + 1; + + // // When next packet id is at the edge of inflight queue, + // // set await flag. This instructs eventloop to stop + // // processing requests until all the inflight publishes + // // are acked + // if next_pkid == self.max_inflight { + // self.last_pkid = 0; + // return next_pkid; + // } + + // self.last_pkid = next_pkid; + // next_pkid + //} } #[cfg(test)] @@ -521,10 +555,10 @@ mod test { #[test] fn next_pkid_increments_as_expected() { - let mut mqtt = build_mqttstate(); + let mqtt = build_mqttstate(); for i in 1..=100 { - let pkid = mqtt.next_pkid(); + let pkid = mqtt.increment_pkid(); // loops between 0-99. % 100 == 0 implies border let expected = i % 100; @@ -545,7 +579,7 @@ mod test { // QoS 0 publish shouldn't be saved in queue mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.last_pkid, 0); + assert_eq!(mqtt.cur_pkid(), 0); assert_eq!(mqtt.inflight, 0); // QoS1 Publish @@ -553,12 +587,12 @@ mod test { // Packet id should be set and publish should be saved in queue mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.last_pkid, 1); + assert_eq!(mqtt.cur_pkid(), 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.last_pkid, 2); + assert_eq!(mqtt.cur_pkid(), 2); assert_eq!(mqtt.inflight, 2); // QoS1 Publish @@ -566,12 +600,12 @@ mod test { // Packet id should be set and publish should be saved in queue mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.last_pkid, 3); + assert_eq!(mqtt.cur_pkid(), 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.last_pkid, 4); + assert_eq!(mqtt.cur_pkid(), 4); assert_eq!(mqtt.inflight, 4); } From eff7193e34f4d794dc8911f9585f6f2a5b4238a5 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:03:01 +0530 Subject: [PATCH 18/41] rumqttc: v5: remove cancel channel Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 16 +------------ rumqttc/src/v5/client/publisher.rs | 9 ------- rumqttc/src/v5/client/syncclient.rs | 5 ---- rumqttc/src/v5/eventloop.rs | 35 +--------------------------- 4 files changed, 2 insertions(+), 63 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index d811cb2fc..478b189fd 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -26,18 +26,16 @@ pub struct AsyncClient { sub_events_buf_cache: VecDeque, request_buf_capacity: usize, request_tx: Sender<()>, - pub(crate) cancel_tx: Sender<()>, } impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { - let mut eventloop = EventLoop::new(options, cap); + let eventloop = EventLoop::new(options, cap); let request_buf = eventloop.request_buf().clone(); let sub_events_buf = eventloop.sub_events_buf().clone(); let sub_events_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); - let cancel_tx = eventloop.cancel_handle(); let max_inflight = eventloop.state.max_inflight; let pkid_counter = eventloop.state.pkid_counter().clone(); @@ -49,7 +47,6 @@ impl AsyncClient { max_inflight, sub_events_buf_cache, request_tx, - cancel_tx, }; (client, eventloop) @@ -63,7 +60,6 @@ impl AsyncClient { pkid_counter: Arc, max_inflight: u16, request_tx: Sender<()>, - cancel_tx: Sender<()>, cap: usize, ) -> AsyncClient { AsyncClient { @@ -74,7 +70,6 @@ impl AsyncClient { sub_events_buf, sub_events_buf_cache: VecDeque::with_capacity(cap), request_tx, - cancel_tx, } } @@ -209,14 +204,6 @@ impl AsyncClient { self.try_send_and_notify(Request::Disconnect) } - /// Stops the eventloop right away - pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx - .send_async(()) - .await - .map_err(ClientError::Cancel) - } - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { let mut request_buf = self.request_buf.lock().unwrap(); @@ -278,7 +265,6 @@ impl AsyncClient { pkid_counter: self.pkid_counter, max_inflight: self.max_inflight, request_tx: self.request_tx.clone(), - cancel_tx: self.cancel_tx.clone(), publish_topic: publish_topic.into(), publish_qos, }; diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs index e4de4af8b..8cf924254 100644 --- a/rumqttc/src/v5/client/publisher.rs +++ b/rumqttc/src/v5/client/publisher.rs @@ -14,7 +14,6 @@ pub struct Publisher { pub(crate) pkid_counter: Arc, pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, - pub(crate) cancel_tx: Sender<()>, pub(crate) publish_topic: String, pub(crate) publish_qos: QoS, } @@ -83,14 +82,6 @@ impl Publisher { self.try_send_and_notify(Request::Disconnect) } - /// Stops the eventloop right away - pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx - .send_async(()) - .await - .map_err(ClientError::Cancel) - } - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.request_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index bed105012..749b06627 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -127,9 +127,4 @@ impl Client { pub fn try_disconnect(&mut self) -> Result<(), ClientError> { self.client.try_disconnect() } - - /// Stops the eventloop right away - pub fn cancel(&mut self) -> Result<(), ClientError> { - self.client.cancel_tx.send(()).map_err(ClientError::Cancel) - } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 7c1843100..0c03fe3bf 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -64,10 +64,6 @@ pub struct EventLoop { pub(crate) network: Option, /// Keep alive time pub(crate) keepalive_timeout: Option>>, - /// Handle to read cancellation requests - pub(crate) cancel_rx: Receiver<()>, - /// Handle to send cancellation requests (and drops) - pub(crate) cancel_tx: Sender<()>, } /// Events which can be yielded by the event loop @@ -83,7 +79,6 @@ impl EventLoop { /// When connection encounters critical errors (like auth failure), user has a choice to /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { - let (cancel_tx, cancel_rx) = bounded(5); let (requests_tx, requests_rx) = bounded(1); let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); @@ -101,8 +96,6 @@ impl EventLoop { pending, network: None, keepalive_timeout: None, - cancel_rx, - cancel_tx, } } @@ -119,14 +112,6 @@ impl EventLoop { &self.state.sub_events_buf } - /// Handle for cancelling the eventloop. - /// - /// Can be useful in cases when connection should be halted immediately - /// between half-open connection detections or (re)connection timeouts - pub(crate) fn cancel_handle(&mut self) -> Sender<()> { - self.cancel_tx.clone() - } - fn clean(&mut self) { self.network = None; self.keepalive_timeout = None; @@ -140,7 +125,7 @@ impl EventLoop { /// **NOTE** Don't block this while iterating pub async fn poll(&mut self) -> Result { if self.network.is_none() { - let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?; + let (network, connack) = connect(&self.options).await?; self.network = Some(network); if self.keepalive_timeout.is_none() { @@ -243,29 +228,11 @@ impl EventLoop { network.flush(&mut self.state.write).await?; return Ok(self.state.events.pop_front().unwrap()) } - // cancellation requests to stop the polling - _ = self.cancel_rx.recv_async() => { - return Err(ConnectionError::Cancel) - } } } } } -async fn connect_or_cancel( - options: &MqttOptions, - cancel_rx: &Receiver<()>, -) -> Result<(Network, Incoming), ConnectionError> { - // select here prevents cancel request from being blocked until connection request is - // resolved. Returns with an error if connections fail continuously - select! { - o = connect(options) => o, - _ = cancel_rx.recv_async() => { - Err(ConnectionError::Cancel) - } - } -} - /// This stream internally processes requests from the request stream provided to the eventloop /// while also consuming byte stream from the network and yielding mqtt packets as the output of /// the stream. From 3aae9cf998967b2b002449ac6a95e737af658b6a Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:13:08 +0530 Subject: [PATCH 19/41] rumqttc: v5: renaming buffers Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 62 ++++++++++++++-------------- rumqttc/src/v5/client/publisher.rs | 12 +++--- rumqttc/src/v5/client/subscriber.rs | 18 ++++---- rumqttc/src/v5/eventloop.rs | 32 +++++++------- rumqttc/src/v5/state.rs | 8 ++-- 5 files changed, 66 insertions(+), 66 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 478b189fd..84a75b8f0 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -19,12 +19,12 @@ use crate::v5::{ /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Clone, Debug)] pub struct AsyncClient { - request_buf: Arc>>, - sub_events_buf: Arc>>, + incoming_buf: Arc>>, + outgoing_buf: Arc>>, pkid_counter: Arc, max_inflight: u16, - sub_events_buf_cache: VecDeque, - request_buf_capacity: usize, + incoming_buf_cache: VecDeque, + incoming_buf_capacity: usize, request_tx: Sender<()>, } @@ -33,19 +33,19 @@ impl AsyncClient { pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); let request_buf = eventloop.request_buf().clone(); - let sub_events_buf = eventloop.sub_events_buf().clone(); - let sub_events_buf_cache = VecDeque::with_capacity(cap); + let incoming_buf = eventloop.state.incoming_buf.clone(); + let incoming_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { - request_buf, - request_buf_capacity: cap, - sub_events_buf, + incoming_buf: request_buf, + incoming_buf_capacity: cap, + outgoing_buf: incoming_buf, pkid_counter, max_inflight, - sub_events_buf_cache, + incoming_buf_cache, request_tx, }; @@ -56,19 +56,19 @@ impl AsyncClient { /// creating a test instance. pub fn from_senders( request_buf: Arc>>, - sub_events_buf: Arc>>, + incoming_buf: Arc>>, pkid_counter: Arc, max_inflight: u16, request_tx: Sender<()>, cap: usize, ) -> AsyncClient { AsyncClient { - request_buf, - request_buf_capacity: cap, + incoming_buf: request_buf, + incoming_buf_capacity: cap, pkid_counter, max_inflight, - sub_events_buf, - sub_events_buf_cache: VecDeque::with_capacity(cap), + outgoing_buf: incoming_buf, + incoming_buf_cache: VecDeque::with_capacity(cap), request_tx, } } @@ -206,8 +206,8 @@ impl AsyncClient { async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -219,8 +219,8 @@ impl AsyncClient { } pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -231,8 +231,8 @@ impl AsyncClient { } fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -243,15 +243,15 @@ impl AsyncClient { } pub fn next_publish(&mut self) -> Option { - if let Some(publish) = self.sub_events_buf_cache.pop_front() { + if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } std::mem::swap( - &mut self.sub_events_buf_cache, - &mut *self.sub_events_buf.lock().unwrap(), + &mut self.incoming_buf_cache, + &mut *self.outgoing_buf.lock().unwrap(), ); - self.sub_events_buf_cache.pop_front() + self.incoming_buf_cache.pop_front() } pub async fn split( @@ -260,8 +260,8 @@ impl AsyncClient { publish_qos: QoS, ) -> Result<(Publisher, Subscriber), ClientError> { let publisher = Publisher { - request_buf: self.request_buf.clone(), - request_buf_capacity: self.request_buf_capacity, + incoming_buf: self.incoming_buf.clone(), + incoming_buf_capacity: self.incoming_buf_capacity, pkid_counter: self.pkid_counter, max_inflight: self.max_inflight, request_tx: self.request_tx.clone(), @@ -269,10 +269,10 @@ impl AsyncClient { publish_qos, }; let subscriber = Subscriber { - request_buf: self.request_buf, - sub_events_buf: self.sub_events_buf, - sub_events_buf_cache: self.sub_events_buf_cache, - request_buf_capacity: self.request_buf_capacity, + outgoing_buf: self.incoming_buf, + incoming_buf: self.outgoing_buf, + incoming_buf_cache: self.incoming_buf_cache, + request_buf_capacity: self.incoming_buf_capacity, request_tx: self.request_tx, }; Ok((publisher, subscriber)) diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs index 8cf924254..3b5c07f2d 100644 --- a/rumqttc/src/v5/client/publisher.rs +++ b/rumqttc/src/v5/client/publisher.rs @@ -9,8 +9,8 @@ use flume::{SendError, Sender, TrySendError}; use crate::v5::{packet::Publish, ClientError, QoS, Request}; pub struct Publisher { - pub(crate) request_buf: Arc>>, - pub(crate) request_buf_capacity: usize, + pub(crate) incoming_buf: Arc>>, + pub(crate) incoming_buf_capacity: usize, pub(crate) pkid_counter: Arc, pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, @@ -60,8 +60,8 @@ impl Publisher { async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -83,8 +83,8 @@ impl Publisher { } fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); diff --git a/rumqttc/src/v5/client/subscriber.rs b/rumqttc/src/v5/client/subscriber.rs index 6a7889f18..d1982bddd 100644 --- a/rumqttc/src/v5/client/subscriber.rs +++ b/rumqttc/src/v5/client/subscriber.rs @@ -12,24 +12,24 @@ use crate::v5::{ #[derive(Debug, Clone)] pub struct Subscriber { - pub(crate) request_buf: Arc>>, - pub(crate) sub_events_buf: Arc>>, - pub(crate) sub_events_buf_cache: VecDeque, + pub(crate) outgoing_buf: Arc>>, + pub(crate) incoming_buf: Arc>>, + pub(crate) incoming_buf_cache: VecDeque, pub(crate) request_buf_capacity: usize, pub(crate) request_tx: Sender<()>, } impl Subscriber { pub fn next_publish(&mut self) -> Option { - if let Some(publish) = self.sub_events_buf_cache.pop_front() { + if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } std::mem::swap( - &mut self.sub_events_buf_cache, - &mut *self.sub_events_buf.lock().unwrap(), + &mut self.incoming_buf_cache, + &mut *self.incoming_buf.lock().unwrap(), ); - self.sub_events_buf_cache.pop_front() + self.incoming_buf_cache.pop_front() } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. @@ -95,7 +95,7 @@ impl Subscriber { async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.request_buf.lock().unwrap(); + let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { return Err(ClientError::RequestsFull); } @@ -108,7 +108,7 @@ impl Subscriber { } fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); + let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { return Err(ClientError::RequestsFull); } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 0c03fe3bf..d8b8ad4cf 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -52,12 +52,12 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, - request_buf: Arc>>, - request_buf_cache: VecDeque, + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, /// Request stream - pub requests_rx: Receiver<()>, + pub incoming_rx: Receiver<()>, /// Requests handle to send requests - pub requests_tx: Sender<()>, + pub incoming_tx: Sender<()>, /// Pending packets from last session pub pending: IntoIter, /// Network connection to the broker @@ -79,7 +79,7 @@ impl EventLoop { /// When connection encounters critical errors (like auth failure), user has a choice to /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { - let (requests_tx, requests_rx) = bounded(1); + let (incoming_tx, incoming_rx) = bounded(1); let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); let pending = pending.into_iter(); @@ -89,10 +89,10 @@ impl EventLoop { EventLoop { options, state: MqttState::new(max_inflight, manual_acks, cap), - request_buf, - request_buf_cache: VecDeque::with_capacity(cap), - requests_tx, - requests_rx, + incoming_buf: request_buf, + incoming_buf_cache: VecDeque::with_capacity(cap), + incoming_tx, + incoming_rx, pending, network: None, keepalive_timeout: None, @@ -101,15 +101,15 @@ impl EventLoop { /// Returns a handle to communicate with this eventloop pub fn handle(&self) -> Sender<()> { - self.requests_tx.clone() + self.incoming_tx.clone() } pub fn request_buf(&self) -> &Arc>> { - &self.request_buf + &self.incoming_buf } pub fn sub_events_buf(&self) -> &Arc>> { - &self.state.sub_events_buf + &self.state.incoming_buf } fn clean(&mut self) { @@ -194,14 +194,14 @@ impl EventLoop { // After collision with pkid 1 -> [1b ,2, x, 4, 5]. // 1a is saved to state and event loop is set to collision mode stopping new // outgoing requests (along with 1b). - o = self.requests_rx.recv_async(), if !inflight_full && !pending && !collision => match o { + o = self.incoming_rx.recv_async(), if !inflight_full && !pending && !collision => match o { Ok(_request_notif) => { // swapping to avoid blocking the mutex - std::mem::swap(&mut self.request_buf_cache,&mut *self.request_buf.lock().unwrap()); - if self.request_buf_cache.is_empty() { + std::mem::swap(&mut self.incoming_buf_cache,&mut *self.incoming_buf.lock().unwrap()); + if self.incoming_buf_cache.is_empty() { continue; } - for request in self.request_buf_cache.drain(..) { + for request in self.incoming_buf_cache.drain(..) { self.state.handle_outgoing_packet(request)?; } network.flush(&mut self.state.write).await?; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index c990f92d0..8ad0e5375 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -78,7 +78,7 @@ pub struct MqttState { pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - pub(crate) sub_events_buf: Arc>>, + pub(crate) incoming_buf: Arc>>, pkid_counter: Arc, } @@ -103,7 +103,7 @@ impl MqttState { events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, - sub_events_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), pkid_counter: Arc::new(AtomicU16::new(0)), } } @@ -249,8 +249,8 @@ impl MqttState { } } - // TODO: maybe limit the capacity of `self.sub_events_buf` - self.sub_events_buf + // TODO: maybe limit the capacity of `self.incoming_buf` + self.incoming_buf .lock() .unwrap() .push_back(publish.clone()); From cd05cd5d285b4f2e5c5e567e0d229814066b52cd Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:13:49 +0530 Subject: [PATCH 20/41] rumqttc: v5: asyncclient: remove `AsyncClient::from_senders(..)` Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 84a75b8f0..9e72ab45c 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -52,27 +52,6 @@ impl AsyncClient { (client, eventloop) } - /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for - /// creating a test instance. - pub fn from_senders( - request_buf: Arc>>, - incoming_buf: Arc>>, - pkid_counter: Arc, - max_inflight: u16, - request_tx: Sender<()>, - cap: usize, - ) -> AsyncClient { - AsyncClient { - incoming_buf: request_buf, - incoming_buf_capacity: cap, - pkid_counter, - max_inflight, - outgoing_buf: incoming_buf, - incoming_buf_cache: VecDeque::with_capacity(cap), - request_tx, - } - } - /// Sends a MQTT Publish to the eventloop pub async fn publish( &self, From 1c42f8f947c3e39a86e31326a737792dc7d2e088 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:15:35 +0530 Subject: [PATCH 21/41] rumqttc: v5: asyncclient: internal functions rename Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 32 ++++++++++++++-------------- rumqttc/src/v5/client/syncclient.rs | 12 +++++------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 9e72ab45c..f7e1f3783 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -68,7 +68,7 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.send_async_and_notify(Request::Publish(publish)) + self.push_and_async_notify(Request::Publish(publish)) .await?; Ok(pkid) } @@ -89,14 +89,14 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.try_send_and_notify(Request::Publish(publish))?; + self.push_and_try_notify(Request::Publish(publish))?; Ok(pkid) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.send_async_and_notify(ack).await?; + self.push_and_async_notify(ack).await?; } Ok(()) } @@ -104,7 +104,7 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.try_send_and_notify(ack)?; + self.push_and_try_notify(ack)?; } Ok(()) } @@ -124,21 +124,21 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.send_async_and_notify(Request::Publish(publish)).await?; + self.push_and_async_notify(Request::Publish(publish)).await?; Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.send_async_and_notify(Request::Subscribe(subscribe)) + self.push_and_async_notify(Request::Subscribe(subscribe)) .await } /// Sends a MQTT Subscribe to the eventloop pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.try_send_and_notify(Request::Subscribe(subscribe)) + self.push_and_try_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -147,7 +147,7 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.send_async_and_notify(Request::Subscribe(subscribe)) + self.push_and_async_notify(Request::Subscribe(subscribe)) .await } @@ -157,33 +157,33 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.try_send_and_notify(Request::Subscribe(subscribe)) + self.push_and_try_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) + self.push_and_async_notify(Request::Unsubscribe(unsubscribe)) .await } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) + self.push_and_try_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { - self.send_async_and_notify(Request::Disconnect).await + self.push_and_async_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.try_send_and_notify(Request::Disconnect) + self.push_and_try_notify(Request::Disconnect) } - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + async fn push_and_async_notify(&self, request: Request) -> Result<(), ClientError> { { let mut request_buf = self.incoming_buf.lock().unwrap(); if request_buf.len() == self.incoming_buf_capacity { @@ -197,7 +197,7 @@ impl AsyncClient { Ok(()) } - pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + pub(crate) fn push_and_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.incoming_buf.lock().unwrap(); if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); @@ -209,7 +209,7 @@ impl AsyncClient { Ok(()) } - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + fn push_and_try_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.incoming_buf.lock().unwrap(); if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index 749b06627..e517db9ff 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -44,7 +44,7 @@ impl Client { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - self.client.send_and_notify(Request::Publish(publish))?; + self.client.push_and_notify(Request::Publish(publish))?; Ok(pkid) } @@ -65,7 +65,7 @@ impl Client { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.client.send_and_notify(ack)?; + self.client.push_and_notify(ack)?; } Ok(()) } @@ -78,7 +78,7 @@ impl Client { /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.client.send_and_notify(Request::Subscribe(subscribe)) + self.client.push_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe to the eventloop @@ -96,7 +96,7 @@ impl Client { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.client.send_and_notify(Request::Subscribe(subscribe)) + self.client.push_and_notify(Request::Subscribe(subscribe)) } pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> @@ -110,7 +110,7 @@ impl Client { pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); self.client - .send_and_notify(Request::Unsubscribe(unsubscribe)) + .push_and_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT Unsubscribe to the eventloop @@ -120,7 +120,7 @@ impl Client { /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - self.client.send_and_notify(Request::Disconnect) + self.client.push_and_notify(Request::Disconnect) } /// Sends a MQTT disconnect to the eventloop From fe7beb99075faad2e31bd7c0ffdabd308c57c7d5 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:14:06 +0530 Subject: [PATCH 22/41] rumqttc: merge master (fixed) Signed-off-by: Abhik Jain --- rumqttc/examples/websocket.rs | 2 +- rumqttc/src/v4/mod.rs | 127 +++++++++++++++++++++++++++++++--- rumqttc/src/v5/eventloop.rs | 8 ++- rumqttc/src/v5/mod.rs | 33 +++++---- 4 files changed, 140 insertions(+), 30 deletions(-) diff --git a/rumqttc/examples/websocket.rs b/rumqttc/examples/websocket.rs index 712427b3c..b5d1e39a4 100644 --- a/rumqttc/examples/websocket.rs +++ b/rumqttc/examples/websocket.rs @@ -1,5 +1,5 @@ #[cfg(feature = "websocket")] -use rumqttc::{self, AsyncClient, MqttOptions, QoS, Transport}; +use rumqttc::v4::{AsyncClient, MqttOptions, QoS, Transport}; #[cfg(feature = "websocket")] use std::{error::Error, time::Duration}; #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs index 709b327bf..0c39b6478 100644 --- a/rumqttc/src/v4/mod.rs +++ b/rumqttc/src/v4/mod.rs @@ -1,4 +1,103 @@ +//! A pure rust MQTT client which strives to be robust, efficient and easy to use. +//! This library is backed by an async (tokio) eventloop which handles all the +//! robustness and and efficiency parts of MQTT but naturally fits into both sync +//! and async worlds as we'll see +//! +//! Let's jump into examples right away +//! +//! A simple synchronous publish and subscribe +//! ---------------------------- +//! +//! ```no_run +//! use rumqttc::v4::{MqttOptions, Client, QoS}; +//! use std::time::Duration; +//! use std::thread; +//! +//! let mut mqttoptions = MqttOptions::new("rumqtt-sync", "test.mosquitto.org", 1883); +//! mqttoptions.set_keep_alive(Duration::from_secs(5)); +//! +//! let (mut client, mut connection) = Client::new(mqttoptions, 10); +//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).unwrap(); +//! thread::spawn(move || for i in 0..10 { +//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).unwrap(); +//! thread::sleep(Duration::from_millis(100)); +//! }); +//! +//! // Iterate to poll the eventloop for connection progress +//! for (i, notification) in connection.iter().enumerate() { +//! println!("Notification = {:?}", notification); +//! } +//! ``` +//! +//! A simple asynchronous publish and subscribe +//! ------------------------------ +//! +//! ```no_run +//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; +//! use tokio::{task, time}; +//! use std::time::Duration; +//! use std::error::Error; +//! +//! # #[tokio::main(worker_threads = 1)] +//! # async fn main() { +//! let mut mqttoptions = MqttOptions::new("rumqtt-async", "test.mosquitto.org", 1883); +//! mqttoptions.set_keep_alive(Duration::from_secs(5)); +//! +//! let (mut client, mut eventloop) = AsyncClient::new(mqttoptions, 10); +//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).await.unwrap(); +//! +//! task::spawn(async move { +//! for i in 0..10 { +//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).await.unwrap(); +//! time::sleep(Duration::from_millis(100)).await; +//! } +//! }); +//! +//! loop { +//! let notification = eventloop.poll().await.unwrap(); +//! println!("Received = {:?}", notification); +//! } +//! # } +//! ``` +//! +//! Quick overview of features +//! - Eventloop orchestrates outgoing/incoming packets concurrently and hadles the state +//! - Pings the broker when necessary and detects client side half open connections as well +//! - Throttling of outgoing packets (todo) +//! - Queue size based flow control on outgoing packets +//! - Automatic reconnections by just continuing the `eventloop.poll()/connection.iter()` loop` +//! - Natural backpressure to client APIs during bad network +//! - Immediate cancellation with `client.cancel()` +//! +//! In short, everything necessary to maintain a robust connection +//! +//! Since the eventloop is externally polled (with `iter()/poll()` in a loop) +//! out side the library and `Eventloop` is accessible, users can +//! - Distribute incoming messages based on topics +//! - Stop it when required +//! - Access internal state for use cases like graceful shutdown or to modify options before reconnection +//! +//! ## Important notes +//! +//! - Looping on `connection.iter()`/`eventloop.poll()` is necessary to run the +//! event loop and make progress. It yields incoming and outgoing activity +//! notifications which allows customization as you see fit. +//! +//! - Blocking inside the `connection.iter()`/`eventloop.poll()` loop will block +//! connection progress. +//! +//! ## FAQ +//! **Connecting to a broker using raw ip doesn't work** +//! +//! You cannot create a TLS connection to a bare IP address with a self-signed +//! certificate. This is a [limitation of rustls](https://github.com/ctz/rustls/issues/184). +//! One workaround, which only works under *nix/BSD-like systems, is to add an +//! entry to wherever your DNS resolver looks (e.g. `/etc/hosts`) for the bare IP +//! address and use that name in your code. +#![cfg_attr(docsrs, feature(doc_cfg))] + use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] use std::sync::Arc; use std::time::Duration; @@ -6,15 +105,18 @@ mod client; mod eventloop; mod framed; mod state; +#[cfg(feature = "use-rustls")] mod tls; -pub use crate::mqttbytes::v4::*; -pub use crate::mqttbytes::*; pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use crate::mqttbytes::v4::*; +pub use crate::mqttbytes::*; pub use state::{MqttState, StateError}; -pub use tls::Error; +#[cfg(feature = "use-rustls")] +pub use tls::Error as TlsError; +#[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -92,14 +194,15 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, + #[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] Wss(TlsConfiguration), } @@ -116,6 +219,7 @@ impl Transport { } /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] pub fn tls( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -130,6 +234,7 @@ impl Transport { Self::tls_with_config(config) } + #[cfg(feature = "use-rustls")] pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { Self::Tls(tls_config) } @@ -147,8 +252,8 @@ impl Transport { } /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -163,14 +268,15 @@ impl Transport { Self::wss_with_config(config) } - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { Self::Wss(tls_config) } } #[derive(Clone)] +#[cfg(feature = "use-rustls")] pub enum TlsConfiguration { Simple { /// connection method @@ -184,6 +290,7 @@ pub enum TlsConfiguration { Rustls(Arc), } +#[cfg(feature = "use-rustls")] impl From for TlsConfiguration { fn from(config: ClientConfig) -> Self { TlsConfiguration::Rustls(Arc::new(config)) @@ -613,7 +720,7 @@ mod test { } #[test] - #[cfg(feature = "websocket")] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index d8b8ad4cf..a28fd1e23 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,7 +1,9 @@ use crate::v5::{ - framed::Network, packet::*, tls, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, + framed::Network, packet::*, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport, }; +#[cfg(feature = "use-rustls")] +use crate::v5::tls; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; @@ -34,6 +36,7 @@ pub enum ConnectionError { Timeout(#[from] Elapsed), #[error("Packet parsing error: {0}")] Mqtt5Bytes(Error), + #[cfg(feature = "use-rustls")] #[error("Network: {0}")] Network(#[from] tls::Error), #[error("I/O: {0}")] @@ -269,8 +272,9 @@ async fn network_connect(options: &MqttOptions) -> Result { - let socket = tls::tls_connect(&options, &tls_config).await?; + let socket = tls::tls_connect(options, &tls_config).await?; Network::new(socket, options.max_incoming_packet_size) } #[cfg(unix)] diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d8d2db4c1..943a14f9b 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,4 +1,5 @@ use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] use std::sync::Arc; use std::time::Duration; @@ -7,6 +8,7 @@ mod eventloop; mod framed; mod packet; mod state; +#[cfg(feature = "use-rustls")] mod tls; pub use client::{AsyncClient, Client, ClientError, Connection}; @@ -14,7 +16,9 @@ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use flume::{SendError, Sender, TrySendError}; pub use packet::*; pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] pub use tls::Error; +#[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -92,14 +96,15 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, +#[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] Wss(TlsConfiguration), } @@ -116,6 +121,7 @@ impl Transport { } /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] pub fn tls( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -130,6 +136,7 @@ impl Transport { Self::tls_with_config(config) } + #[cfg(feature = "use-rustls")] pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { Self::Tls(tls_config) } @@ -147,8 +154,8 @@ impl Transport { } /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -163,14 +170,15 @@ impl Transport { Self::wss_with_config(config) } - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { Self::Wss(tls_config) } } #[derive(Clone)] +#[cfg(feature = "use-rustls")] pub enum TlsConfiguration { Simple { /// connection method @@ -184,6 +192,7 @@ pub enum TlsConfiguration { Rustls(Arc), } +#[cfg(feature = "use-rustls")] impl From for TlsConfiguration { fn from(config: ClientConfig) -> Self { TlsConfiguration::Rustls(Arc::new(config)) @@ -233,7 +242,6 @@ pub struct MqttOptions { /// If set to `true` MQTT acknowledgements are not sent automatically. /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. manual_acks: bool, - override_slow_subs: bool, } impl MqttOptions { @@ -261,7 +269,6 @@ impl MqttOptions { last_will: None, conn_timeout: 5, manual_acks: false, - override_slow_subs: false, } } @@ -410,14 +417,6 @@ impl MqttOptions { pub fn manual_acks(&self) -> bool { self.manual_acks } - - pub fn set_override_slow_subs(&mut self, val: bool) { - self.override_slow_subs = val; - } - - pub fn override_slow_subs(&self) -> bool { - self.override_slow_subs - } } #[cfg(feature = "url")] @@ -623,7 +622,7 @@ mod test { } #[test] - #[cfg(feature = "websocket")] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); From 71c9537f9f322d9325257d6992965d10a58063fe Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:26:14 +0530 Subject: [PATCH 23/41] rumqttc: client: rollback pub-sub seperation Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 26 +----- rumqttc/src/v5/client/mod.rs | 4 - rumqttc/src/v5/client/publisher.rs | 116 ------------------------- rumqttc/src/v5/client/subscriber.rs | 121 --------------------------- 4 files changed, 1 insertion(+), 266 deletions(-) delete mode 100644 rumqttc/src/v5/client/publisher.rs delete mode 100644 rumqttc/src/v5/client/subscriber.rs diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index f7e1f3783..5483da737 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -10,7 +10,7 @@ use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; use crate::v5::{ - client::{get_ack_req, Publisher, Subscriber}, + client::get_ack_req, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, ClientError, EventLoop, MqttOptions, QoS, Request, }; @@ -233,30 +233,6 @@ impl AsyncClient { self.incoming_buf_cache.pop_front() } - pub async fn split( - self, - publish_topic: impl Into, - publish_qos: QoS, - ) -> Result<(Publisher, Subscriber), ClientError> { - let publisher = Publisher { - incoming_buf: self.incoming_buf.clone(), - incoming_buf_capacity: self.incoming_buf_capacity, - pkid_counter: self.pkid_counter, - max_inflight: self.max_inflight, - request_tx: self.request_tx.clone(), - publish_topic: publish_topic.into(), - publish_qos, - }; - let subscriber = Subscriber { - outgoing_buf: self.incoming_buf, - incoming_buf: self.outgoing_buf, - incoming_buf_cache: self.incoming_buf_cache, - request_buf_capacity: self.incoming_buf_capacity, - request_tx: self.request_tx, - }; - Ok((publisher, subscriber)) - } - fn increment_pkid(&self) -> u16 { let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); loop { diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index 3156126df..cb47da36e 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -8,10 +8,6 @@ use tokio::runtime::{self, Runtime}; mod asyncclient; pub use asyncclient::AsyncClient; -mod publisher; -pub use publisher::Publisher; -mod subscriber; -pub use subscriber::Subscriber; mod syncclient; pub use syncclient::Client; diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs deleted file mode 100644 index 3b5c07f2d..000000000 --- a/rumqttc/src/v5/client/publisher.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, atomic::{AtomicU16, Ordering}, Mutex}, -}; - -use bytes::Bytes; -use flume::{SendError, Sender, TrySendError}; - -use crate::v5::{packet::Publish, ClientError, QoS, Request}; - -pub struct Publisher { - pub(crate) incoming_buf: Arc>>, - pub(crate) incoming_buf_capacity: usize, - pub(crate) pkid_counter: Arc, - pub(crate) max_inflight: u16, - pub(crate) request_tx: Sender<()>, - pub(crate) publish_topic: String, - pub(crate) publish_qos: QoS, -} - -impl Publisher { - /// Sends a MQTT Publish to the eventloop - pub async fn publish( - &self, - retain: bool, - payload: impl Into>, - ) -> Result { - let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); - publish.retain = retain; - let pkid = self.increment_pkid(); - publish.pkid = pkid; - self.send_async_and_notify(Request::Publish(publish)) - .await?; - Ok(pkid) - } - - /// Sends a MQTT Publish to the eventloop - pub fn try_publish( - &self, - retain: bool, - payload: impl Into>, - ) -> Result { - let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); - publish.retain = retain; - let pkid = self.increment_pkid(); - publish.pkid = pkid; - self.try_send_and_notify(Request::Publish(publish))?; - Ok(pkid) - } - - /// Sends a MQTT Publish to the eventloop - pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result { - let mut publish = Publish::from_bytes(&self.publish_topic, self.publish_qos, payload); - let pkid = self.increment_pkid(); - publish.pkid = pkid; - publish.retain = retain; - self.send_async_and_notify(Request::Publish(publish)).await?; - Ok(pkid) - } - - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { - { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - } - if let Err(SendError(_)) = self.request_tx.send_async(()).await { - return Err(ClientError::EventloopClosed); - }; - Ok(()) - } - - /// Sends a MQTT disconnect to the eventloop - pub async fn disconnect(&self) -> Result<(), ClientError> { - self.send_async_and_notify(Request::Disconnect).await - } - - /// Sends a MQTT disconnect to the eventloop - pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.try_send_and_notify(Request::Disconnect) - } - - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { - return Err(ClientError::EventloopClosed); - } - Ok(()) - } - - fn increment_pkid(&self) -> u16 { - let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); - loop { - let new_pkid = if cur_pkid > self.max_inflight { - 1 - } else { - cur_pkid + 1 - }; - match self.pkid_counter.compare_exchange( - cur_pkid, - new_pkid, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_prev_pkid) => break new_pkid, - Err(actual_pkid) => cur_pkid = actual_pkid, - } - } - } -} diff --git a/rumqttc/src/v5/client/subscriber.rs b/rumqttc/src/v5/client/subscriber.rs deleted file mode 100644 index d1982bddd..000000000 --- a/rumqttc/src/v5/client/subscriber.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; - -use flume::{SendError, Sender, TrySendError}; - -use crate::v5::{ - client::get_ack_req, ClientError, Publish, QoS, Request, Subscribe, SubscribeFilter, - Unsubscribe, -}; - -#[derive(Debug, Clone)] -pub struct Subscriber { - pub(crate) outgoing_buf: Arc>>, - pub(crate) incoming_buf: Arc>>, - pub(crate) incoming_buf_cache: VecDeque, - pub(crate) request_buf_capacity: usize, - pub(crate) request_tx: Sender<()>, -} - -impl Subscriber { - pub fn next_publish(&mut self) -> Option { - if let Some(publish) = self.incoming_buf_cache.pop_front() { - return Some(publish); - } - - std::mem::swap( - &mut self.incoming_buf_cache, - &mut *self.incoming_buf.lock().unwrap(), - ); - self.incoming_buf_cache.pop_front() - } - - /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { - if let Some(ack) = get_ack_req(qos, pkid) { - self.send_async_and_notify(ack).await?; - } - Ok(()) - } - - /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { - if let Some(ack) = get_ack_req(qos, pkid) { - self.try_send_and_notify(ack)?; - } - Ok(()) - } - - /// Sends a MQTT Subscribe to the eventloop - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.send_async_and_notify(Request::Subscribe(subscribe)) - .await - } - - /// Sends a MQTT Subscribe to the eventloop - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.try_send_and_notify(Request::Subscribe(subscribe)) - } - - /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> - where - T: IntoIterator, - { - let subscribe = Subscribe::new_many(topics); - self.send_async_and_notify(Request::Subscribe(subscribe)) - .await - } - - /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> - where - T: IntoIterator, - { - let subscribe = Subscribe::new_many(topics); - self.try_send_and_notify(Request::Subscribe(subscribe)) - } - - /// Sends a MQTT Unsubscribe to the eventloop - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); - self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) - .await - } - - /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); - self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) - } - - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { - { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - } - if let Err(SendError(_)) = self.request_tx.send_async(()).await { - return Err(ClientError::EventloopClosed); - }; - Ok(()) - } - - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { - return Err(ClientError::EventloopClosed); - } - Ok(()) - } -} From 5e047d34c82f176e344c642a640f2a57d9ca2a47 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:50:31 +0530 Subject: [PATCH 24/41] rumqttc: eventloop: remove `Eventloop.event`, `Event` and `Outgoing` Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 22 +++++---- rumqttc/examples/async_manual_acks_v5.rs | 29 ++++++------ rumqttc/src/v5/client/asyncclient.rs | 11 +++-- rumqttc/src/v5/client/mod.rs | 6 +-- rumqttc/src/v5/eventloop.rs | 43 ++++++----------- rumqttc/src/v5/mod.rs | 29 +----------- rumqttc/src/v5/state.rs | 59 ++---------------------- 7 files changed, 58 insertions(+), 141 deletions(-) diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index c89112332..000a060b7 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{AsyncClient, EventLoop, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; @@ -44,21 +44,23 @@ async fn main() -> Result<(), Box> { } // create new broker connection - let (client, mut eventloop) = create_conn(); + let (_client, mut eventloop) = create_conn(); loop { // previously published messages should be republished after reconnection. let event = eventloop.poll().await; println!("{:?}", event); - if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { - // this time we will ack incoming publishes. - // Its important not to block eventloop as this can cause deadlock. - let c = client.clone(); - tokio::spawn(async move { - c.ack(&publish).await.unwrap(); - }); - } + todo!("fix the commented out code below") + + // if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { + // // this time we will ack incoming publishes. + // // Its important not to block eventloop as this can cause deadlock. + // let c = client.clone(); + // tokio::spawn(async move { + // c.ack(&publish).await.unwrap(); + // }); + // } } } diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs index eb01d2e3d..471df0f40 100644 --- a/rumqttc/examples/async_manual_acks_v5.rs +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v5::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; @@ -44,23 +44,26 @@ async fn main() -> Result<(), Box> { } // create new broker connection - let (client, mut eventloop) = create_conn(); + let (_client, mut eventloop) = create_conn(); loop { // previously published messages should be republished after reconnection. let event = eventloop.poll().await; println!("{:?}", event); - match event { - Ok(Event::Incoming(Incoming::Publish(publish))) => { - // this time we will ack incoming publishes. - // Its important not to block eventloop as this can cause deadlock. - let c = client.clone(); - tokio::spawn(async move { - c.ack(&publish).await.unwrap(); - }); - } - _ => {} - } + + todo!("fix the commented out code below") + + // match event { + // Ok(Event::Incoming(Incoming::Publish(publish))) => { + // // this time we will ack incoming publishes. + // // Its important not to block eventloop as this can cause deadlock. + // let c = client.clone(); + // tokio::spawn(async move { + // c.ack(&publish).await.unwrap(); + // }); + // } + // _ => {} + // } } } diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 5483da737..da778daee 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -12,7 +12,7 @@ use flume::{SendError, Sender, TrySendError}; use crate::v5::{ client::get_ack_req, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, - ClientError, EventLoop, MqttOptions, QoS, Request, + ClientError, EventLoop, Incoming, MqttOptions, QoS, Request, }; /// `AsyncClient` to communicate with MQTT `Eventloop` @@ -20,10 +20,10 @@ use crate::v5::{ #[derive(Clone, Debug)] pub struct AsyncClient { incoming_buf: Arc>>, - outgoing_buf: Arc>>, + outgoing_buf: Arc>>, pkid_counter: Arc, max_inflight: u16, - incoming_buf_cache: VecDeque, + incoming_buf_cache: VecDeque, incoming_buf_capacity: usize, request_tx: Sender<()>, } @@ -124,7 +124,8 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.push_and_async_notify(Request::Publish(publish)).await?; + self.push_and_async_notify(Request::Publish(publish)) + .await?; Ok(pkid) } @@ -221,7 +222,7 @@ impl AsyncClient { Ok(()) } - pub fn next_publish(&mut self) -> Option { + pub fn next_publish(&mut self) -> Option { if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index cb47da36e..811b6889d 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -1,6 +1,6 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::v5::{packet::*, ConnectionError, Event, EventLoop, Request}; +use crate::v5::{packet::*, ConnectionError, EventLoop, Request}; use flume::SendError; use std::mem; @@ -69,12 +69,12 @@ pub struct Iter<'a> { } impl<'a> Iterator for Iter<'a> { - type Item = Result; + type Item = Result<(), ConnectionError>; fn next(&mut self) -> Option { let f = self.connection.eventloop.poll(); match self.runtime.block_on(f) { - Ok(v) => Some(Ok(v)), + Ok(_) => Some(Ok(())), // closing of request channel should stop the iterator Err(ConnectionError::RequestsDone) => { trace!("Done with requests"); diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index a28fd1e23..74edd7a6c 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,5 +1,5 @@ use crate::v5::{ - framed::Network, packet::*, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, + framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, StateError, Transport, }; #[cfg(feature = "use-rustls")] @@ -69,13 +69,6 @@ pub struct EventLoop { pub(crate) keepalive_timeout: Option>>, } -/// Events which can be yielded by the event loop -#[derive(Debug, PartialEq, Clone)] -pub enum Event { - Incoming(Incoming), - Outgoing(Outgoing), -} - impl EventLoop { /// New MQTT `EventLoop` /// @@ -111,7 +104,7 @@ impl EventLoop { &self.incoming_buf } - pub fn sub_events_buf(&self) -> &Arc>> { + pub fn sub_events_buf(&self) -> &Arc>> { &self.state.incoming_buf } @@ -126,29 +119,28 @@ impl EventLoop { /// the broker. Continuing to poll will reconnect to the broker if there is /// a disconnection. /// **NOTE** Don't block this while iterating - pub async fn poll(&mut self) -> Result { + pub async fn poll(&mut self) -> Result<(), ConnectionError> { if self.network.is_none() { - let (network, connack) = connect(&self.options).await?; + let (network, _connack) = connect(&self.options).await?; self.network = Some(network); if self.keepalive_timeout.is_none() { self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); } - return Ok(Event::Incoming(connack)); + return Ok(()); } - match self.select().await { - Ok(v) => Ok(v), - Err(e) => { - self.clean(); - Err(e) - } + if let Err(e) = self.select().await { + self.clean(); + return Err(e); } + + Ok(()) } /// Select on network and requests and generate keepalive pings when necessary - async fn select(&mut self) -> Result { + async fn select(&mut self) -> Result<(), ConnectionError> { let network = self.network.as_mut().unwrap(); // let await_acks = self.state.await_acks; let inflight_full = self.state.inflight >= self.options.inflight; @@ -156,11 +148,6 @@ impl EventLoop { let pending = self.pending.len() > 0; let collision = self.state.collision.is_some(); - // Read buffered events from previous polls before calling a new poll - if let Some(event) = self.state.events.pop_front() { - return Ok(event); - } - // this loop is necessary as self.request_buf might be empty, in which case it is possible // for self.state.events to be empty, and so popping off from it might return None. If None // is returned, we select again. @@ -171,7 +158,7 @@ impl EventLoop { o?; // flush all the acks and return first incoming packet network.flush(&mut self.state.write).await?; - return Ok(self.state.events.pop_front().unwrap()); + return Ok(()); }, // Pull next request from user requests channel. // If conditions in the below branch are for flow control. We read next user @@ -210,7 +197,7 @@ impl EventLoop { network.flush(&mut self.state.write).await?; // remaining events in the self.state.events will be taken out in next call // to poll() even before the select! is used. - return Ok(self.state.events.pop_front().unwrap()) + return Ok(()) } Err(_) => return Err(ConnectionError::RequestsDone), }, @@ -219,7 +206,7 @@ impl EventLoop { Some(request) = next_pending(throttle, &mut self.pending), if pending => { self.state.handle_outgoing_packet(request)?; network.flush(&mut self.state.write).await?; - return Ok(self.state.events.pop_front().unwrap()) + return Ok(()) }, // We generate pings irrespective of network activity. This keeps the ping logic // simple. We can change this behavior in future if necessary (to prevent extra pings) @@ -229,7 +216,7 @@ impl EventLoop { self.state.handle_outgoing_packet(Request::PingReq)?; network.flush(&mut self.state.write).await?; - return Ok(self.state.events.pop_front().unwrap()) + return Ok(()) } } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 943a14f9b..2f26f4cc2 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -12,7 +12,7 @@ mod state; mod tls; pub use client::{AsyncClient, Client, ClientError, Connection}; -pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use eventloop::{ConnectionError, EventLoop}; pub use flume::{SendError, Sender, TrySendError}; pub use packet::*; pub use state::{MqttState, StateError}; @@ -23,33 +23,6 @@ pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; -/// Current outgoing activity on the eventloop -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum Outgoing { - /// Publish packet with packet identifier. 0 implies QoS 0 - Publish(u16), - /// Subscribe packet with packet identifier - Subscribe(u16), - /// Unsubscribe packet with packet identifier - Unsubscribe(u16), - /// PubAck packet - PubAck(u16), - /// PubRec packet - PubRec(u16), - /// PubRel packet - PubRel(u16), - /// PubComp packet - PubComp(u16), - /// Ping request packet - PingReq, - /// Ping response packet - PingResp, - /// Disconnect packet - Disconnect, - /// Await for an ack for more outgoing progress - AwaitAck(u16), -} - /// Requests by the client to mqtt event loop. Request are /// handled one by one. #[derive(Clone, Debug, PartialEq)] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8ad0e5375..5af3c1b8d 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,4 +1,4 @@ -use super::{packet::*, Event, Incoming, Outgoing, Request}; +use super::{packet::*, Incoming, Request}; use bytes::BytesMut; use std::collections::VecDeque; @@ -72,13 +72,11 @@ pub struct MqttState { pub(crate) incoming_pub: Vec>, /// Last collision due to broker not acking in order pub collision: Option, - /// Buffered incoming packets - pub events: VecDeque, /// Write buffer pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - pub(crate) incoming_buf: Arc>>, + pub(crate) incoming_buf: Arc>>, pkid_counter: Arc, } @@ -100,7 +98,6 @@ impl MqttState { incoming_pub: vec![None; std::u16::MAX as usize + 1], collision: None, // TODO: Optimize these sizes later - events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), @@ -213,7 +210,7 @@ impl MqttState { }; out?; - self.events.push_back(Event::Incoming(packet)); + self.incoming_buf.lock().unwrap().push_back(packet); self.last_incoming = Instant::now(); Ok(()) } @@ -249,12 +246,6 @@ impl MqttState { } } - // TODO: maybe limit the capacity of `self.incoming_buf` - self.incoming_buf - .lock() - .unwrap() - .push_back(publish.clone()); - Ok(()) } @@ -275,8 +266,6 @@ impl MqttState { self.inflight += 1; publish.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); self.collision_ping_count = 0; } @@ -289,9 +278,6 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); PubRel::new(pubrec.pkid).write(&mut self.write)?; - - let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); - self.events.push_back(event); Ok(()) } None => { @@ -305,8 +291,6 @@ impl MqttState { match mem::replace(&mut self.incoming_pub[pubrel.pkid as usize], None) { Some(_) => { PubComp::new(pubrel.pkid).write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); - self.events.push_back(event); Ok(()) } None => { @@ -319,8 +303,6 @@ impl MqttState { fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { publish.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); self.collision_ping_count = 0; } @@ -358,8 +340,6 @@ impl MqttState { { info!("Collision on packet id = {:?}", publish.pkid); self.collision = Some(publish); - let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); - self.events.push_back(event); return Ok(()); } @@ -377,8 +357,6 @@ impl MqttState { ); publish.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); Ok(()) } @@ -387,23 +365,16 @@ impl MqttState { debug!("Pubrel. Pkid = {}", pubrel.pkid); PubRel::new(pubrel.pkid).write(&mut self.write)?; - - let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); - self.events.push_back(event); Ok(()) } fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { puback.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); - self.events.push_back(event); Ok(()) } fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { pubrec.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); - self.events.push_back(event); Ok(()) } @@ -437,8 +408,6 @@ impl MqttState { ); PingReq.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PingReq); - self.events.push_back(event); Ok(()) } @@ -452,8 +421,6 @@ impl MqttState { ); subscription.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); - self.events.push_back(event); Ok(()) } @@ -467,8 +434,6 @@ impl MqttState { ); unsub.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); - self.events.push_back(event); Ok(()) } @@ -476,8 +441,6 @@ impl MqttState { debug!("Disconnect"); Disconnect::new().write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Disconnect); - self.events.push_back(event); Ok(()) } @@ -528,7 +491,7 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::v5::{packet::*, Event, Incoming, MqttOptions, Outgoing, Request}; + use crate::v5::{packet::*, Incoming, MqttOptions, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); @@ -640,18 +603,6 @@ mod test { mqtt.handle_incoming_publish(&publish1).unwrap(); mqtt.handle_incoming_publish(&publish2).unwrap(); mqtt.handle_incoming_publish(&publish3).unwrap(); - - if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { - assert_eq!(pkid, 2); - } else { - panic!("missing puback") - } - - if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] { - assert_eq!(pkid, 3); - } else { - panic!("missing PubRec") - } } #[test] @@ -671,7 +622,7 @@ mod test { let pkid = mqtt.incoming_pub[3].unwrap(); assert_eq!(pkid, 3); - assert!(mqtt.events.is_empty()); + assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); } #[test] From 4198d93c541f3fa8e21f03551ecc562ec3dd1881 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:53:07 +0530 Subject: [PATCH 25/41] rumqttc: client: rename `AsyncClient::next_publish()` -> `AsyncClient::next()` Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index da778daee..d048d9462 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -222,7 +222,7 @@ impl AsyncClient { Ok(()) } - pub fn next_publish(&mut self) -> Option { + pub fn next(&mut self) -> Option { if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } From 61a750630feb655a072a5d658abc3c8f1110d9b8 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 22:01:52 +0530 Subject: [PATCH 26/41] rumqttc: state: some inlines Signed-off-by: Abhik Jain --- rumqttc/src/v5/state.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 5af3c1b8d..cf9d2a16c 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -166,6 +166,7 @@ impl MqttState { pending } + #[inline] pub fn inflight(&self) -> u16 { self.inflight } @@ -215,10 +216,12 @@ impl MqttState { Ok(()) } + #[inline] fn handle_incoming_suback(&mut self) -> Result<(), StateError> { Ok(()) } + #[inline] fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { Ok(()) } @@ -318,6 +321,7 @@ impl MqttState { } } + #[inline] fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { self.await_pingresp = false; Ok(()) @@ -360,6 +364,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { let pubrel = self.save_pubrel(pubrel)?; @@ -368,11 +373,13 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { puback.write(&mut self.write)?; Ok(()) } + #[inline] fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { pubrec.write(&mut self.write)?; Ok(()) @@ -411,6 +418,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { let pkid = self.increment_pkid(); subscription.pkid = pkid; @@ -424,6 +432,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { let pkid = self.increment_pkid(); unsub.pkid = pkid; @@ -437,6 +446,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_disconnect(&mut self) -> Result<(), StateError> { debug!("Disconnect"); @@ -444,6 +454,7 @@ impl MqttState { Ok(()) } + #[inline] fn check_collision(&mut self, pkid: u16) -> Option { if let Some(publish) = &self.collision { if publish.pkid == pkid { @@ -454,6 +465,7 @@ impl MqttState { None } + #[inline] fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets From 05df29d6939f17842595bade563ddee01cf35350 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 25 Mar 2022 11:33:54 +0530 Subject: [PATCH 27/41] rumqttc: ci: ignore packet parsing in cargo clippy Signed-off-by: Abhik Jain --- rumqttc/src/lib.rs | 1 + rumqttc/src/v5/client/asyncclient.rs | 66 +++++++++++++--------------- rumqttc/src/v5/mod.rs | 42 ++++++++++++++++-- rumqttc/src/v5/notifier.rs | 56 +++++++++++++++++++++++ 4 files changed, 126 insertions(+), 39 deletions(-) create mode 100644 rumqttc/src/v5/notifier.rs diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index a579c1be7..e74d6adbf 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -99,6 +99,7 @@ #[macro_use] extern crate log; +#[allow(clippy::all)] pub mod mqttbytes; pub mod v4; pub mod v5; diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index d048d9462..c824f98f5 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -12,40 +12,34 @@ use flume::{SendError, Sender, TrySendError}; use crate::v5::{ client::get_ack_req, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, - ClientError, EventLoop, Incoming, MqttOptions, QoS, Request, + ClientError, EventLoop, MqttOptions, QoS, Request, }; /// `AsyncClient` to communicate with MQTT `Eventloop` /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Clone, Debug)] pub struct AsyncClient { - incoming_buf: Arc>>, - outgoing_buf: Arc>>, - pkid_counter: Arc, - max_inflight: u16, - incoming_buf_cache: VecDeque, - incoming_buf_capacity: usize, - request_tx: Sender<()>, + pub(crate) outgoing_buf: Arc>>, + pub(crate) outgoing_buf_capacity: usize, + pub(crate) pkid_counter: Arc, + pub(crate) max_inflight: u16, + pub(crate) request_tx: Sender<()>, } impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); - let request_buf = eventloop.request_buf().clone(); - let incoming_buf = eventloop.state.incoming_buf.clone(); - let incoming_buf_cache = VecDeque::with_capacity(cap); + let outgoing_buf = eventloop.request_buf().clone(); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { - incoming_buf: request_buf, - incoming_buf_capacity: cap, - outgoing_buf: incoming_buf, + outgoing_buf, + outgoing_buf_capacity: cap, pkid_counter, max_inflight, - incoming_buf_cache, request_tx, }; @@ -66,7 +60,11 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = self.increment_pkid(); + let pkid = if qos != QoS::AtMostOnce { + self.increment_pkid() + } else { + 0 + }; publish.pkid = pkid; self.push_and_async_notify(Request::Publish(publish)) .await?; @@ -87,7 +85,11 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = self.increment_pkid(); + let pkid = if qos != QoS::AtMostOnce { + self.increment_pkid() + } else { + 0 + }; publish.pkid = pkid; self.push_and_try_notify(Request::Publish(publish))?; Ok(pkid) @@ -122,7 +124,11 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let pkid = self.increment_pkid(); + let pkid = if qos != QoS::AtMostOnce { + self.increment_pkid() + } else { + 0 + }; publish.pkid = pkid; self.push_and_async_notify(Request::Publish(publish)) .await?; @@ -186,8 +192,8 @@ impl AsyncClient { async fn push_and_async_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.len() == self.outgoing_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -199,8 +205,8 @@ impl AsyncClient { } pub(crate) fn push_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.len() == self.outgoing_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -211,8 +217,8 @@ impl AsyncClient { } fn push_and_try_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.len() == self.outgoing_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -222,18 +228,6 @@ impl AsyncClient { Ok(()) } - pub fn next(&mut self) -> Option { - if let Some(publish) = self.incoming_buf_cache.pop_front() { - return Some(publish); - } - - std::mem::swap( - &mut self.incoming_buf_cache, - &mut *self.outgoing_buf.lock().unwrap(), - ); - self.incoming_buf_cache.pop_front() - } - fn increment_pkid(&self) -> u16 { let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); loop { diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 2f26f4cc2..d9a62d333 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,11 +1,16 @@ -use std::fmt::{self, Debug, Formatter}; #[cfg(feature = "use-rustls")] use std::sync::Arc; -use std::time::Duration; +use std::{ + collections::VecDeque, + fmt::{self, Debug, Formatter}, + time::Duration, +}; mod client; mod eventloop; mod framed; +mod notifier; +#[allow(clippy::all)] mod packet; mod state; #[cfg(feature = "use-rustls")] @@ -20,6 +25,7 @@ pub use state::{MqttState, StateError}; pub use tls::Error; #[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; +pub use notifier::Notifier; pub type Incoming = Packet; @@ -69,7 +75,7 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, -#[cfg(feature = "use-rustls")] + #[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, @@ -584,6 +590,36 @@ impl Debug for MqttOptions { } } +pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, Notifier), ()> { + let mut eventloop = EventLoop::new(options, cap); + let outgoing_buf = eventloop.request_buf().clone(); + let incoming_buf = eventloop.state.incoming_buf.clone(); + let incoming_buf_cache = VecDeque::with_capacity(cap); + let request_tx = eventloop.handle(); + let max_inflight = eventloop.state.max_inflight; + let pkid_counter = eventloop.state.pkid_counter().clone(); + + let client = AsyncClient { + outgoing_buf, + incoming_buf_capacity: cap, + incoming_buf, + pkid_counter, + max_inflight, + incoming_buf_cache, + request_tx, + }; + + tokio::spawn(async move { + loop { + // TODO: maybe do something like retries for some specific errors? or maybe give user + // options to configure these retries? + eventloop.poll().await.unwrap(); + } + }); + + Ok((client, Notifier {})) +} + #[cfg(test)] mod test { use super::*; diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs new file mode 100644 index 000000000..00554fed1 --- /dev/null +++ b/rumqttc/src/v5/notifier.rs @@ -0,0 +1,56 @@ +use std::{ + collections::VecDeque, + mem, + sync::{Arc, Mutex}, +}; + +use crate::v5::Incoming; + +#[derive(Debug)] +pub struct Notifier { + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, +} + +impl Notifier { + #[inline] + pub(crate) fn new( + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, + ) -> Self { + Self { + incoming_buf, + incoming_buf_cache, + } + } + + #[inline] + pub fn next(&mut self) -> Option { + match self.incoming_buf_cache.pop_front() { + None => { + mem::replace( + &mut self.incoming_buf_cache, + *self.incoming_buf.lock().unwrap(), + ); + self.incoming_buf_cache.pop_front() + } + val => val, + } + } + + #[inline] + pub fn iter(&mut self) -> NotifierIter<'_> { + NotifierIter(self) + } +} + +pub struct NotifierIter<'a>(&'a mut Notifier); + +impl<'a> Iterator for NotifierIter<'a> { + type Item = Incoming; + + #[inline] + fn next(&mut self) -> Option { + self.0.next() + } +} From 7d028766871c4abfe2cd6a370360043132c40c32 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 25 Mar 2022 11:38:45 +0530 Subject: [PATCH 28/41] rumqttc: ci: ignore packet parsing in cargo clippy Signed-off-by: Abhik Jain --- rumqttc/src/v4/tls.rs | 2 +- rumqttc/src/v5/mod.rs | 8 +++----- rumqttc/src/v5/notifier.rs | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/rumqttc/src/v4/tls.rs b/rumqttc/src/v4/tls.rs index ea1809e23..72eca8aa8 100644 --- a/rumqttc/src/v4/tls.rs +++ b/rumqttc/src/v4/tls.rs @@ -96,7 +96,7 @@ pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result return Err(Error::NoValidCertInChain), }; - let certs = certs.into_iter().map(|cert| Certificate(cert)).collect(); + let certs = certs.into_iter().map(Certificate).collect(); config.with_single_cert(certs, PrivateKey(key))? } else { diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d9a62d333..6e6d4428a 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -19,13 +19,13 @@ mod tls; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, EventLoop}; pub use flume::{SendError, Sender, TrySendError}; +pub use notifier::Notifier; pub use packet::*; pub use state::{MqttState, StateError}; #[cfg(feature = "use-rustls")] pub use tls::Error; #[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; -pub use notifier::Notifier; pub type Incoming = Packet; @@ -601,11 +601,9 @@ pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, N let client = AsyncClient { outgoing_buf, - incoming_buf_capacity: cap, - incoming_buf, + outgoing_buf_capacity: cap, pkid_counter, max_inflight, - incoming_buf_cache, request_tx, }; @@ -617,7 +615,7 @@ pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, N } }); - Ok((client, Notifier {})) + Ok((client, Notifier::new(incoming_buf, incoming_buf_cache))) } #[cfg(test)] diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs index 00554fed1..4aad90d16 100644 --- a/rumqttc/src/v5/notifier.rs +++ b/rumqttc/src/v5/notifier.rs @@ -28,9 +28,9 @@ impl Notifier { pub fn next(&mut self) -> Option { match self.incoming_buf_cache.pop_front() { None => { - mem::replace( + mem::swap( &mut self.incoming_buf_cache, - *self.incoming_buf.lock().unwrap(), + &mut *self.incoming_buf.lock().unwrap(), ); self.incoming_buf_cache.pop_front() } From c4b2522df929abe182a16ef194669076096e5a1c Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sun, 27 Mar 2022 09:20:58 +0530 Subject: [PATCH 29/41] rumqttc: handle pkids at client side Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks_v5.rs | 86 ++-- rumqttc/examples/asyncpubsub_v5.rs | 2 +- rumqttc/src/v5/client/asyncclient.rs | 114 +++-- rumqttc/src/v5/client/syncclient.rs | 36 +- rumqttc/src/v5/mod.rs | 3 +- rumqttc/src/v5/state.rs | 620 +++++++++++------------ 6 files changed, 416 insertions(+), 445 deletions(-) diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs index 471df0f40..fde55ad81 100644 --- a/rumqttc/examples/async_manual_acks_v5.rs +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -1,3 +1,4 @@ +#![allow(dead_code, unused_imports)] use tokio::{task, time}; use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions, QoS}; @@ -16,58 +17,59 @@ fn create_conn() -> (AsyncClient, EventLoop) { #[tokio::main(worker_threads = 1)] async fn main() -> Result<(), Box> { - pretty_env_logger::init(); + todo!("fix this example with new way of spawning clients") + // pretty_env_logger::init(); - // create mqtt connection with clean_session = false and manual_acks = true - let (client, mut eventloop) = create_conn(); + // // create mqtt connection with clean_session = false and manual_acks = true + // let (client, mut eventloop) = create_conn(); - // subscribe example topic - client - .subscribe("hello/world", QoS::AtLeastOnce) - .await - .unwrap(); + // // subscribe example topic + // client + // .subscribe("hello/world", QoS::AtLeastOnce) + // .await + // .unwrap(); - task::spawn(async move { - // send some messages to example topic and disconnect - requests(client.clone()).await; - client.disconnect().await.unwrap() - }); + // task::spawn(async move { + // // send some messages to example topic and disconnect + // requests(client.clone()).await; + // client.disconnect().await.unwrap() + // }); - loop { - // get subscribed messages without acking - let event = eventloop.poll().await; - println!("{:?}", event); - if let Err(_err) = event { - // break loop on disconnection - break; - } - } + // loop { + // // get subscribed messages without acking + // let event = eventloop.poll().await; + // println!("{:?}", event); + // if let Err(_err) = event { + // // break loop on disconnection + // break; + // } + // } - // create new broker connection - let (_client, mut eventloop) = create_conn(); + // // create new broker connection + // let (_client, mut eventloop) = create_conn(); - loop { - // previously published messages should be republished after reconnection. - let event = eventloop.poll().await; - println!("{:?}", event); + // loop { + // // previously published messages should be republished after reconnection. + // let event = eventloop.poll().await; + // println!("{:?}", event); - todo!("fix the commented out code below") + // todo!("fix the commented out code below") - // match event { - // Ok(Event::Incoming(Incoming::Publish(publish))) => { - // // this time we will ack incoming publishes. - // // Its important not to block eventloop as this can cause deadlock. - // let c = client.clone(); - // tokio::spawn(async move { - // c.ack(&publish).await.unwrap(); - // }); - // } - // _ => {} - // } - } + // // match event { + // // Ok(Event::Incoming(Incoming::Publish(publish))) => { + // // // this time we will ack incoming publishes. + // // // Its important not to block eventloop as this can cause deadlock. + // // let c = client.clone(); + // // tokio::spawn(async move { + // // c.ack(&publish).await.unwrap(); + // // }); + // // } + // // _ => {} + // // } + // } } -async fn requests(client: AsyncClient) { +async fn requests(mut client: AsyncClient) { for i in 1..=10 { client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i]) diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs index b0b7a6698..b398a5e3f 100644 --- a/rumqttc/examples/asyncpubsub_v5.rs +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), Box> { } } -async fn requests(client: AsyncClient) { +async fn requests(mut client: AsyncClient) { client .subscribe("hello/world", QoS::AtMostOnce) .await diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index c824f98f5..798439e79 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -1,9 +1,6 @@ use std::{ collections::VecDeque, - sync::{ - atomic::{AtomicU16, Ordering}, - Arc, Mutex, - }, + sync::{Arc, Mutex}, }; use bytes::Bytes; @@ -17,11 +14,11 @@ use crate::v5::{ /// `AsyncClient` to communicate with MQTT `Eventloop` /// This is cloneable and can be used to asynchronously Publish, Subscribe. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct AsyncClient { pub(crate) outgoing_buf: Arc>>, pub(crate) outgoing_buf_capacity: usize, - pub(crate) pkid_counter: Arc, + pub(crate) pkid_counter: u16, pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, } @@ -33,12 +30,11 @@ impl AsyncClient { let outgoing_buf = eventloop.request_buf().clone(); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; - let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { outgoing_buf, outgoing_buf_capacity: cap, - pkid_counter, + pkid_counter: 0, max_inflight, request_tx, }; @@ -48,7 +44,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub async fn publish( - &self, + &mut self, topic: S, qos: QoS, retain: bool, @@ -73,7 +69,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub fn try_publish( - &self, + &mut self, topic: S, qos: QoS, retain: bool, @@ -96,7 +92,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { self.push_and_async_notify(ack).await?; } @@ -104,7 +100,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { self.push_and_try_notify(ack)?; } @@ -113,7 +109,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub async fn publish_bytes( - &self, + &mut self, topic: S, qos: QoS, retain: bool, @@ -136,57 +132,83 @@ impl AsyncClient { } /// Sends a MQTT Subscribe to the eventloop - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); + pub async fn subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; self.push_and_async_notify(Request::Subscribe(subscribe)) - .await + .await?; + Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.push_and_try_notify(Request::Subscribe(subscribe)) + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; + self.push_and_try_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); + let mut subscribe = Subscribe::new_many(topics); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; self.push_and_async_notify(Request::Subscribe(subscribe)) - .await + .await?; + Ok(pkid) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); - self.push_and_try_notify(Request::Subscribe(subscribe)) + let mut subscribe = Subscribe::new_many(topics); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; + self.push_and_try_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub async fn unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = self.increment_pkid(); + unsubscribe.pkid = pkid; self.push_and_async_notify(Request::Unsubscribe(unsubscribe)) - .await + .await?; + Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); - self.push_and_try_notify(Request::Unsubscribe(unsubscribe)) + pub fn try_unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = self.increment_pkid(); + unsubscribe.pkid = pkid; + self.push_and_try_notify(Request::Unsubscribe(unsubscribe))?; + Ok(pkid) } /// Sends a MQTT disconnect to the eventloop - pub async fn disconnect(&self) -> Result<(), ClientError> { + pub async fn disconnect(&mut self) -> Result<(), ClientError> { self.push_and_async_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { self.push_and_try_notify(Request::Disconnect) } @@ -228,23 +250,13 @@ impl AsyncClient { Ok(()) } - fn increment_pkid(&self) -> u16 { - let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); - loop { - let new_pkid = if cur_pkid > self.max_inflight { - 1 - } else { - cur_pkid + 1 - }; - match self.pkid_counter.compare_exchange( - cur_pkid, - new_pkid, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_prev_pkid) => break new_pkid, - Err(actual_pkid) => cur_pkid = actual_pkid, - } - } + #[inline] + pub(crate) fn increment_pkid(&mut self) -> u16 { + self.pkid_counter = if self.pkid_counter == self.max_inflight { + 1 + } else { + self.pkid_counter + 1 + }; + self.pkid_counter } } diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index e517db9ff..ae94a1119 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -10,7 +10,6 @@ use crate::v5::{ /// /// Client is cloneable and can be used to synchronously Publish, Subscribe. /// Asynchronous channel handle can also be extracted if necessary -#[derive(Clone)] pub struct Client { client: AsyncClient, } @@ -71,14 +70,17 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the eventloop - pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.client.push_and_notify(Request::Subscribe(subscribe)) + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = self.client.increment_pkid(); + subscribe.pkid = pkid; + self.client.push_and_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop @@ -86,20 +88,23 @@ impl Client { &mut self, topic: S, qos: QoS, - ) -> Result<(), ClientError> { + ) -> Result { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); - self.client.push_and_notify(Request::Subscribe(subscribe)) + let mut subscribe = Subscribe::new_many(topics); + let pkid = self.client.increment_pkid(); + subscribe.pkid = pkid; + self.client.push_and_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } - pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { @@ -107,14 +112,17 @@ impl Client { } /// Sends a MQTT Unsubscribe to the eventloop - pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub fn unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = self.client.increment_pkid(); + unsubscribe.pkid = pkid; self.client - .push_and_notify(Request::Unsubscribe(unsubscribe)) + .push_and_notify(Request::Unsubscribe(unsubscribe))?; + Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&mut self, topic: S) -> Result { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 6e6d4428a..615d18148 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -597,12 +597,11 @@ pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, N let incoming_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; - let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { outgoing_buf, outgoing_buf_capacity: cap, - pkid_counter, + pkid_counter: 0, max_inflight, request_tx, }; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index cf9d2a16c..0b304f666 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -2,7 +2,6 @@ use super::{packet::*, Incoming, Request}; use bytes::BytesMut; use std::collections::VecDeque; -use std::sync::atomic::{AtomicU16, Ordering}; use std::{ io, mem, sync::{Arc, Mutex}, @@ -77,7 +76,6 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, pub(crate) incoming_buf: Arc>>, - pkid_counter: Arc, } impl MqttState { @@ -101,38 +99,6 @@ impl MqttState { write: BytesMut::with_capacity(10 * 1024), manual_acks, incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), - pkid_counter: Arc::new(AtomicU16::new(0)), - } - } - - #[inline] - pub(crate) fn pkid_counter(&self) -> &Arc { - &self.pkid_counter - } - - #[cfg(test)] - #[inline] - fn cur_pkid(&self) -> u16 { - self.pkid_counter.load(Ordering::SeqCst) - } - - pub(crate) fn increment_pkid(&self) -> u16 { - let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); - loop { - let new_pkid = if cur_pkid > self.max_inflight { - 1 - } else { - cur_pkid + 1 - }; - match self.pkid_counter.compare_exchange( - cur_pkid, - new_pkid, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_prev_pkid) => break new_pkid, - Err(actual_pkid) => cur_pkid = actual_pkid, - } } } @@ -329,12 +295,9 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, publish: Publish) -> Result<(), StateError> { if publish.qos != QoS::AtMostOnce { - if publish.pkid == 0 { - publish.pkid = self.increment_pkid(); - } - + // client should set proper pkid let pkid = publish.pkid; if self .outgoing_pub @@ -419,10 +382,8 @@ impl MqttState { } #[inline] - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { - let pkid = self.increment_pkid(); - subscription.pkid = pkid; - + fn outgoing_subscribe(&mut self, subscription: Subscribe) -> Result<(), StateError> { + // client should set correct pkid debug!( "Subscribe. Topics = {:?}, Pkid = {:?}", subscription.filters, subscription.pkid @@ -433,10 +394,7 @@ impl MqttState { } #[inline] - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { - let pkid = self.increment_pkid(); - unsub.pkid = pkid; - + fn outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), StateError> { debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", unsub.filters, unsub.pkid @@ -466,16 +424,8 @@ impl MqttState { } #[inline] - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { - let pubrel = match pubrel.pkid { - // consider PacketIdentifier(0) as uninitialized packets - 0 => { - pubrel.pkid = self.increment_pkid(); - pubrel - } - _ => pubrel, - }; - + fn save_pubrel(&mut self, pubrel: PubRel) -> Result { + // pubrel's pkid should already be set correct self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid); Ok(pubrel) } @@ -502,282 +452,282 @@ impl MqttState { #[cfg(test)] mod test { - use super::{MqttState, StateError}; - use crate::v5::{packet::*, Incoming, MqttOptions, Request}; - - fn build_outgoing_publish(qos: QoS) -> Publish { - let topic = "hello/world".to_owned(); - let payload = vec![1, 2, 3]; - - let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - publish.qos = qos; - publish - } - - fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { - let topic = "hello/world".to_owned(); - let payload = vec![1, 2, 3]; - - let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - publish.pkid = pkid; - publish.qos = qos; - publish - } - - fn build_mqttstate() -> MqttState { - MqttState::new(100, false, 100) - } - - #[test] - fn next_pkid_increments_as_expected() { - let mqtt = build_mqttstate(); - - for i in 1..=100 { - let pkid = mqtt.increment_pkid(); - - // loops between 0-99. % 100 == 0 implies border - let expected = i % 100; - if expected == 0 { - break; - } - - assert_eq!(expected, pkid); - } - } - - #[test] - fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { - let mut mqtt = build_mqttstate(); - - // QoS0 Publish - let publish = build_outgoing_publish(QoS::AtMostOnce); - - // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.cur_pkid(), 0); - assert_eq!(mqtt.inflight, 0); - - // QoS1 Publish - let publish = build_outgoing_publish(QoS::AtLeastOnce); - - // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.cur_pkid(), 1); - assert_eq!(mqtt.inflight, 1); - - // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.cur_pkid(), 2); - assert_eq!(mqtt.inflight, 2); - - // QoS1 Publish - let publish = build_outgoing_publish(QoS::ExactlyOnce); - - // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.cur_pkid(), 3); - assert_eq!(mqtt.inflight, 3); - - // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.cur_pkid(), 4); - assert_eq!(mqtt.inflight, 4); - } - - #[test] - fn incoming_publish_should_be_added_to_queue_correctly() { - let mut mqtt = build_mqttstate(); - - // QoS0, 1, 2 Publishes - let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); - - let pkid = mqtt.incoming_pub[3].unwrap(); - - // only qos2 publish should be add to queue - assert_eq!(pkid, 3); - } - - #[test] - fn incoming_publish_should_be_acked() { - let mut mqtt = build_mqttstate(); - - // QoS0, 1, 2 Publishes - let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); - } - - #[test] - fn incoming_publish_should_not_be_acked_with_manual_acks() { - let mut mqtt = build_mqttstate(); - mqtt.manual_acks = true; - - // QoS0, 1, 2 Publishes - let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); - - let pkid = mqtt.incoming_pub[3].unwrap(); - assert_eq!(pkid, 3); - - assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); - } - - #[test] - fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { - let mut mqtt = build_mqttstate(); - let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - _ => panic!("Invalid network request: {:?}", packet), - } - } - - #[test] - fn incoming_puback_should_remove_correct_publish_from_queue() { - let mut mqtt = build_mqttstate(); - - let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); - assert_eq!(mqtt.inflight, 2); - - mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); - assert_eq!(mqtt.inflight, 1); - - mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); - assert_eq!(mqtt.inflight, 0); - - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); - } - - #[test] - fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { - let mut mqtt = build_mqttstate(); - - let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); - - mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); - assert_eq!(mqtt.inflight, 2); - - // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); - assert_eq!(backup.unwrap().pkid, 1); - - // check if the qos2 element's release pkid is 2 - assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); - } - - #[test] - fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { - let mut mqtt = build_mqttstate(); - - let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::Publish(publish) => assert_eq!(publish.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - } - - #[test] - fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { - let mut mqtt = build_mqttstate(); - let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - - mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - } - - #[test] - fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { - let mut mqtt = build_mqttstate(); - let publish = build_outgoing_publish(QoS::ExactlyOnce); - - mqtt.outgoing_publish(publish).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - - mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); - assert_eq!(mqtt.inflight, 0); - } - - #[test] - fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { - let mut mqtt = build_mqttstate(); - let mut opts = MqttOptions::new("test", "localhost", 1883); - opts.set_keep_alive(std::time::Duration::from_secs(10)); - mqtt.outgoing_ping().unwrap(); - - // network activity other than pingresp - let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) - .unwrap(); - mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) - .unwrap(); - - // should throw error because we didn't get pingresp for previous ping - match mqtt.outgoing_ping() { - Ok(_) => panic!("Should throw pingresp await error"), - Err(StateError::AwaitPingResp) => (), - Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), - } - } - - #[test] - fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { - let mut mqtt = build_mqttstate(); - - let mut opts = MqttOptions::new("test", "localhost", 1883); - opts.set_keep_alive(std::time::Duration::from_secs(10)); - - // should ping - mqtt.outgoing_ping().unwrap(); - mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); - - // should ping - mqtt.outgoing_ping().unwrap(); - } + // use super::{MqttState, StateError}; + // use crate::v5::{packet::*, Incoming, MqttOptions, Request}; + + // fn build_outgoing_publish(qos: QoS) -> Publish { + // let topic = "hello/world".to_owned(); + // let payload = vec![1, 2, 3]; + + // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + // publish.qos = qos; + // publish + // } + + // fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + // let topic = "hello/world".to_owned(); + // let payload = vec![1, 2, 3]; + + // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + // publish.pkid = pkid; + // publish.qos = qos; + // publish + // } + + // fn build_mqttstate() -> MqttState { + // MqttState::new(100, false, 100) + // } + + // #[test] + // fn next_pkid_increments_as_expected() { + // let mqtt = build_mqttstate(); + + // for i in 1..=100 { + // let pkid = mqtt.increment_pkid(); + + // // loops between 0-99. % 100 == 0 implies border + // let expected = i % 100; + // if expected == 0 { + // break; + // } + + // assert_eq!(expected, pkid); + // } + // } + + // #[test] + // fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + // let mut mqtt = build_mqttstate(); + + // // QoS0 Publish + // let publish = build_outgoing_publish(QoS::AtMostOnce); + + // // QoS 0 publish shouldn't be saved in queue + // mqtt.outgoing_publish(publish).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 0); + // assert_eq!(mqtt.inflight, 0); + + // // QoS1 Publish + // let publish = build_outgoing_publish(QoS::AtLeastOnce); + + // // Packet id should be set and publish should be saved in queue + // mqtt.outgoing_publish(publish.clone()).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 1); + // assert_eq!(mqtt.inflight, 1); + + // // Packet id should be incremented and publish should be saved in queue + // mqtt.outgoing_publish(publish).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 2); + // assert_eq!(mqtt.inflight, 2); + + // // QoS1 Publish + // let publish = build_outgoing_publish(QoS::ExactlyOnce); + + // // Packet id should be set and publish should be saved in queue + // mqtt.outgoing_publish(publish.clone()).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 3); + // assert_eq!(mqtt.inflight, 3); + + // // Packet id should be incremented and publish should be saved in queue + // mqtt.outgoing_publish(publish).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 4); + // assert_eq!(mqtt.inflight, 4); + // } + + // #[test] + // fn incoming_publish_should_be_added_to_queue_correctly() { + // let mut mqtt = build_mqttstate(); + + // // QoS0, 1, 2 Publishes + // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + // mqtt.handle_incoming_publish(&publish1).unwrap(); + // mqtt.handle_incoming_publish(&publish2).unwrap(); + // mqtt.handle_incoming_publish(&publish3).unwrap(); + + // let pkid = mqtt.incoming_pub[3].unwrap(); + + // // only qos2 publish should be add to queue + // assert_eq!(pkid, 3); + // } + + // #[test] + // fn incoming_publish_should_be_acked() { + // let mut mqtt = build_mqttstate(); + + // // QoS0, 1, 2 Publishes + // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + // mqtt.handle_incoming_publish(&publish1).unwrap(); + // mqtt.handle_incoming_publish(&publish2).unwrap(); + // mqtt.handle_incoming_publish(&publish3).unwrap(); + // } + + // #[test] + // fn incoming_publish_should_not_be_acked_with_manual_acks() { + // let mut mqtt = build_mqttstate(); + // mqtt.manual_acks = true; + + // // QoS0, 1, 2 Publishes + // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + // mqtt.handle_incoming_publish(&publish1).unwrap(); + // mqtt.handle_incoming_publish(&publish2).unwrap(); + // mqtt.handle_incoming_publish(&publish3).unwrap(); + + // let pkid = mqtt.incoming_pub[3].unwrap(); + // assert_eq!(pkid, 3); + + // assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); + // } + + // #[test] + // fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + // let mut mqtt = build_mqttstate(); + // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + // mqtt.handle_incoming_publish(&publish).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + // _ => panic!("Invalid network request: {:?}", packet), + // } + // } + + // #[test] + // fn incoming_puback_should_remove_correct_publish_from_queue() { + // let mut mqtt = build_mqttstate(); + + // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + // mqtt.outgoing_publish(publish1).unwrap(); + // mqtt.outgoing_publish(publish2).unwrap(); + // assert_eq!(mqtt.inflight, 2); + + // mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + // assert_eq!(mqtt.inflight, 1); + + // mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + // assert_eq!(mqtt.inflight, 0); + + // assert!(mqtt.outgoing_pub[1].is_none()); + // assert!(mqtt.outgoing_pub[2].is_none()); + // } + + // #[test] + // fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + // let mut mqtt = build_mqttstate(); + + // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + // let _publish_out = mqtt.outgoing_publish(publish1); + // let _publish_out = mqtt.outgoing_publish(publish2); + + // mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + // assert_eq!(mqtt.inflight, 2); + + // // check if the remaining element's pkid is 1 + // let backup = mqtt.outgoing_pub[1].clone(); + // assert_eq!(backup.unwrap().pkid, 1); + + // // check if the qos2 element's release pkid is 2 + // assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + // } + + // #[test] + // fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + // let mut mqtt = build_mqttstate(); + + // let publish = build_outgoing_publish(QoS::ExactlyOnce); + // mqtt.outgoing_publish(publish).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + + // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + // } + + // #[test] + // fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + // let mut mqtt = build_mqttstate(); + // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + // mqtt.handle_incoming_publish(&publish).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + + // mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + // } + + // #[test] + // fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + // let mut mqtt = build_mqttstate(); + // let publish = build_outgoing_publish(QoS::ExactlyOnce); + + // mqtt.outgoing_publish(publish).unwrap(); + // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + // mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + // assert_eq!(mqtt.inflight, 0); + // } + + // #[test] + // fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + // let mut mqtt = build_mqttstate(); + // let mut opts = MqttOptions::new("test", "localhost", 1883); + // opts.set_keep_alive(std::time::Duration::from_secs(10)); + // mqtt.outgoing_ping().unwrap(); + + // // network activity other than pingresp + // let publish = build_outgoing_publish(QoS::AtLeastOnce); + // mqtt.handle_outgoing_packet(Request::Publish(publish)) + // .unwrap(); + // mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + // .unwrap(); + + // // should throw error because we didn't get pingresp for previous ping + // match mqtt.outgoing_ping() { + // Ok(_) => panic!("Should throw pingresp await error"), + // Err(StateError::AwaitPingResp) => (), + // Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + // } + // } + + // #[test] + // fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + // let mut mqtt = build_mqttstate(); + + // let mut opts = MqttOptions::new("test", "localhost", 1883); + // opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // // should ping + // mqtt.outgoing_ping().unwrap(); + // mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // // should ping + // mqtt.outgoing_ping().unwrap(); + // } } From ea9e2397a2b63a1185cb21a2c37bfc9440de90a5 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sun, 27 Mar 2022 20:47:10 +0530 Subject: [PATCH 30/41] rumqttc: restore backwards compatibilty Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 22 +- rumqttc/examples/asyncpubsub.rs | 2 +- rumqttc/examples/sync_pubsub_weird.rs | 41 -- rumqttc/examples/syncpubsub.rs | 2 +- rumqttc/examples/tls.rs | 2 +- rumqttc/examples/tls2.rs | 3 +- rumqttc/examples/websocket.rs | 2 +- rumqttc/src/{v4 => }/client.rs | 7 +- rumqttc/src/{v4 => }/eventloop.rs | 7 +- rumqttc/src/{v4 => }/framed.rs | 9 +- rumqttc/src/lib.rs | 725 ++++++++++++++++++++++- rumqttc/src/{v4 => }/state.rs | 6 +- rumqttc/src/{v4 => }/tls.rs | 2 +- rumqttc/src/v4/mod.rs | 817 -------------------------- rumqttc/src/v5/client/mod.rs | 1 - rumqttc/src/v5/eventloop.rs | 8 +- rumqttc/tests/broker.rs | 2 +- rumqttc/tests/reliability.rs | 2 +- 18 files changed, 754 insertions(+), 906 deletions(-) delete mode 100644 rumqttc/examples/sync_pubsub_weird.rs rename rumqttc/src/{v4 => }/client.rs (99%) rename rumqttc/src/{v4 => }/eventloop.rs (98%) rename rumqttc/src/{v4 => }/framed.rs (97%) rename rumqttc/src/{v4 => }/state.rs (99%) rename rumqttc/src/{v4 => }/tls.rs (98%) delete mode 100644 rumqttc/src/v4/mod.rs diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 000a060b7..4e9449841 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{AsyncClient, EventLoop, MqttOptions, QoS}; +use rumqttc::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; @@ -44,23 +44,21 @@ async fn main() -> Result<(), Box> { } // create new broker connection - let (_client, mut eventloop) = create_conn(); + let (client, mut eventloop) = create_conn(); loop { // previously published messages should be republished after reconnection. let event = eventloop.poll().await; println!("{:?}", event); - todo!("fix the commented out code below") - - // if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { - // // this time we will ack incoming publishes. - // // Its important not to block eventloop as this can cause deadlock. - // let c = client.clone(); - // tokio::spawn(async move { - // c.ack(&publish).await.unwrap(); - // }); - // } + if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { + // this time we will ack incoming publishes. + // Its important not to block eventloop as this can cause deadlock. + let c = client.clone(); + tokio::spawn(async move { + c.ack(&publish).await.unwrap(); + }); + } } } diff --git a/rumqttc/examples/asyncpubsub.rs b/rumqttc/examples/asyncpubsub.rs index b4de624e7..4ec2cd983 100644 --- a/rumqttc/examples/asyncpubsub.rs +++ b/rumqttc/examples/asyncpubsub.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{AsyncClient, MqttOptions, QoS}; +use rumqttc::{self, AsyncClient, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/sync_pubsub_weird.rs b/rumqttc/examples/sync_pubsub_weird.rs deleted file mode 100644 index e6c7b9c3d..000000000 --- a/rumqttc/examples/sync_pubsub_weird.rs +++ /dev/null @@ -1,41 +0,0 @@ -use rumqttc::v4::{Client, Event, MqttOptions, Packet, QoS}; -use std::thread; -use std::time::Duration; - -fn main() { - let mut mqttoptions = MqttOptions::new("test1", "localhost", 1883); - mqttoptions.set_keep_alive(Duration::from_secs(5)); - mqttoptions.set_inflight(5); - - let (mut client, mut connection) = Client::new(mqttoptions, 25); - - client.subscribe("test/me", QoS::AtMostOnce).unwrap(); - - thread::spawn(move || publish(client)); - - for notification in connection.iter() { - println!("notification: {:?}", notification); - match notification.unwrap() { - Event::Incoming(inc) => match inc { - Packet::Publish(_) => { - thread::sleep(Duration::from_millis(200)); - } - _ => {} - }, - _ => {} - } - } - - println!("Done with the stream!!"); -} - -fn publish(client: Client) { - loop { - let payload = b"Foobar"; - let mut cl2 = client.clone(); - cl2.publish("test/me", QoS::AtLeastOnce, false, &payload[..]) - .unwrap(); - thread::sleep(Duration::from_millis(50)); - println!("send!"); - } -} diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index 2001b4d29..bb2171cfd 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -1,4 +1,4 @@ -use rumqttc::v4::{Client, LastWill, MqttOptions, QoS}; +use rumqttc::{self, Client, LastWill, MqttOptions, QoS}; use std::thread; use std::time::Duration; diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index f765d2f5a..0e25860e3 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -4,7 +4,7 @@ use std::error::Error; #[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { - use rumqttc::v4::{AsyncClient, Event, Incoming, MqttOptions, Transport}; + use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; use rustls::ClientConfig; pretty_env_logger::init(); diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index c2106d585..496a806a0 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,11 +1,10 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. - use std::error::Error; #[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { - use rumqttc::v4::{AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; + use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; pretty_env_logger::init(); color_backtrace::install(); diff --git a/rumqttc/examples/websocket.rs b/rumqttc/examples/websocket.rs index b5d1e39a4..712427b3c 100644 --- a/rumqttc/examples/websocket.rs +++ b/rumqttc/examples/websocket.rs @@ -1,5 +1,5 @@ #[cfg(feature = "websocket")] -use rumqttc::v4::{AsyncClient, MqttOptions, QoS, Transport}; +use rumqttc::{self, AsyncClient, MqttOptions, QoS, Transport}; #[cfg(feature = "websocket")] use std::{error::Error, time::Duration}; #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v4/client.rs b/rumqttc/src/client.rs similarity index 99% rename from rumqttc/src/v4/client.rs rename to rumqttc/src/client.rs index 31a56c6ae..8464df271 100644 --- a/rumqttc/src/v4/client.rs +++ b/rumqttc/src/client.rs @@ -1,10 +1,7 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::{ - mqttbytes, - mqttbytes::v4::*, - v4::{ConnectionError, Event, EventLoop, MqttOptions, QoS, Request}, -}; +use crate::mqttbytes::{self, v4::*, QoS}; +use crate::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; diff --git a/rumqttc/src/v4/eventloop.rs b/rumqttc/src/eventloop.rs similarity index 98% rename from rumqttc/src/v4/eventloop.rs rename to rumqttc/src/eventloop.rs index efb8da1af..e65619b3f 100644 --- a/rumqttc/src/v4/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,8 +1,7 @@ -use crate::v4::{framed::Network, Transport}; -use crate::v4::{Incoming, MqttState, Packet, Request, StateError}; -use crate::v4::{MqttOptions, Outgoing}; +use crate::framed::Network; #[cfg(feature = "use-rustls")] -use crate::v4::tls; +use crate::tls; +use crate::{Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport}; use crate::mqttbytes::v4::*; use async_channel::{bounded, Receiver, Sender}; diff --git a/rumqttc/src/v4/framed.rs b/rumqttc/src/framed.rs similarity index 97% rename from rumqttc/src/v4/framed.rs rename to rumqttc/src/framed.rs index 171f689c7..b0a536e78 100644 --- a/rumqttc/src/v4/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,13 +1,8 @@ use bytes::BytesMut; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::{ - mqttbytes::{ - self, - v4::{read, Connect}, - }, - v4::{Incoming, MqttState, StateError}, -}; +use crate::mqttbytes::{self, v4::*}; +use crate::{Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index e74d6adbf..a124168d3 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -9,7 +9,7 @@ //! ---------------------------- //! //! ```no_run -//! use rumqttc::v4::{MqttOptions, Client, QoS}; +//! use rumqttc::{MqttOptions, Client, QoS}; //! use std::time::Duration; //! use std::thread; //! @@ -33,7 +33,7 @@ //! ------------------------------ //! //! ```no_run -//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; +//! use rumqttc::{MqttOptions, AsyncClient, QoS}; //! use tokio::{task, time}; //! use std::time::Duration; //! use std::error::Error; @@ -99,7 +99,724 @@ #[macro_use] extern crate log; -#[allow(clippy::all)] +use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] +use std::sync::Arc; +use std::time::Duration; + +mod client; +mod eventloop; +mod framed; pub mod mqttbytes; -pub mod v4; +mod state; +#[cfg(feature = "use-rustls")] +mod tls; pub mod v5; + +pub use async_channel::{SendError, Sender, TrySendError}; +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use mqttbytes::v4::*; +pub use mqttbytes::*; +pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] +pub use tls::Error as TlsError; +#[cfg(feature = "use-rustls")] +pub use tokio_rustls::rustls::ClientConfig; + +pub type Incoming = Packet; + +/// Current outgoing activity on the eventloop +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// Subscribe packet with packet identifier + Subscribe(u16), + /// Unsubscribe packet with packet identifier + Unsubscribe(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, + /// Await for an ack for more outgoing progress + AwaitAck(u16), +} + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + #[cfg(feature = "use-rustls")] + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + #[cfg(feature = "use-rustls")] + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +#[cfg(feature = "use-rustls")] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +#[cfg(feature = "use-rustls")] +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false, + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/v4/state.rs b/rumqttc/src/state.rs similarity index 99% rename from rumqttc/src/v4/state.rs rename to rumqttc/src/state.rs index 7f6d5959f..399001cd8 100644 --- a/rumqttc/src/v4/state.rs +++ b/rumqttc/src/state.rs @@ -1,4 +1,4 @@ -use crate::v4::{Event, Incoming, Outgoing, Request}; +use crate::{Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; @@ -487,7 +487,9 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::{mqttbytes::v4::read, v4::*}; + use crate::mqttbytes::v4::*; + use crate::mqttbytes::*; + use crate::{Event, Incoming, MqttOptions, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); diff --git a/rumqttc/src/v4/tls.rs b/rumqttc/src/tls.rs similarity index 98% rename from rumqttc/src/v4/tls.rs rename to rumqttc/src/tls.rs index 72eca8aa8..c740626b3 100644 --- a/rumqttc/src/v4/tls.rs +++ b/rumqttc/src/tls.rs @@ -7,7 +7,7 @@ use tokio_rustls::rustls::{ use tokio_rustls::webpki; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::v4::{Key, MqttOptions, TlsConfiguration}; +use crate::{Key, MqttOptions, TlsConfiguration}; use std::convert::TryFrom; use std::io; diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs deleted file mode 100644 index 0c39b6478..000000000 --- a/rumqttc/src/v4/mod.rs +++ /dev/null @@ -1,817 +0,0 @@ -//! A pure rust MQTT client which strives to be robust, efficient and easy to use. -//! This library is backed by an async (tokio) eventloop which handles all the -//! robustness and and efficiency parts of MQTT but naturally fits into both sync -//! and async worlds as we'll see -//! -//! Let's jump into examples right away -//! -//! A simple synchronous publish and subscribe -//! ---------------------------- -//! -//! ```no_run -//! use rumqttc::v4::{MqttOptions, Client, QoS}; -//! use std::time::Duration; -//! use std::thread; -//! -//! let mut mqttoptions = MqttOptions::new("rumqtt-sync", "test.mosquitto.org", 1883); -//! mqttoptions.set_keep_alive(Duration::from_secs(5)); -//! -//! let (mut client, mut connection) = Client::new(mqttoptions, 10); -//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).unwrap(); -//! thread::spawn(move || for i in 0..10 { -//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).unwrap(); -//! thread::sleep(Duration::from_millis(100)); -//! }); -//! -//! // Iterate to poll the eventloop for connection progress -//! for (i, notification) in connection.iter().enumerate() { -//! println!("Notification = {:?}", notification); -//! } -//! ``` -//! -//! A simple asynchronous publish and subscribe -//! ------------------------------ -//! -//! ```no_run -//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; -//! use tokio::{task, time}; -//! use std::time::Duration; -//! use std::error::Error; -//! -//! # #[tokio::main(worker_threads = 1)] -//! # async fn main() { -//! let mut mqttoptions = MqttOptions::new("rumqtt-async", "test.mosquitto.org", 1883); -//! mqttoptions.set_keep_alive(Duration::from_secs(5)); -//! -//! let (mut client, mut eventloop) = AsyncClient::new(mqttoptions, 10); -//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).await.unwrap(); -//! -//! task::spawn(async move { -//! for i in 0..10 { -//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).await.unwrap(); -//! time::sleep(Duration::from_millis(100)).await; -//! } -//! }); -//! -//! loop { -//! let notification = eventloop.poll().await.unwrap(); -//! println!("Received = {:?}", notification); -//! } -//! # } -//! ``` -//! -//! Quick overview of features -//! - Eventloop orchestrates outgoing/incoming packets concurrently and hadles the state -//! - Pings the broker when necessary and detects client side half open connections as well -//! - Throttling of outgoing packets (todo) -//! - Queue size based flow control on outgoing packets -//! - Automatic reconnections by just continuing the `eventloop.poll()/connection.iter()` loop` -//! - Natural backpressure to client APIs during bad network -//! - Immediate cancellation with `client.cancel()` -//! -//! In short, everything necessary to maintain a robust connection -//! -//! Since the eventloop is externally polled (with `iter()/poll()` in a loop) -//! out side the library and `Eventloop` is accessible, users can -//! - Distribute incoming messages based on topics -//! - Stop it when required -//! - Access internal state for use cases like graceful shutdown or to modify options before reconnection -//! -//! ## Important notes -//! -//! - Looping on `connection.iter()`/`eventloop.poll()` is necessary to run the -//! event loop and make progress. It yields incoming and outgoing activity -//! notifications which allows customization as you see fit. -//! -//! - Blocking inside the `connection.iter()`/`eventloop.poll()` loop will block -//! connection progress. -//! -//! ## FAQ -//! **Connecting to a broker using raw ip doesn't work** -//! -//! You cannot create a TLS connection to a bare IP address with a self-signed -//! certificate. This is a [limitation of rustls](https://github.com/ctz/rustls/issues/184). -//! One workaround, which only works under *nix/BSD-like systems, is to add an -//! entry to wherever your DNS resolver looks (e.g. `/etc/hosts`) for the bare IP -//! address and use that name in your code. -#![cfg_attr(docsrs, feature(doc_cfg))] - -use std::fmt::{self, Debug, Formatter}; -#[cfg(feature = "use-rustls")] -use std::sync::Arc; -use std::time::Duration; - -mod client; -mod eventloop; -mod framed; -mod state; -#[cfg(feature = "use-rustls")] -mod tls; - -pub use async_channel::{SendError, Sender, TrySendError}; -pub use client::{AsyncClient, Client, ClientError, Connection}; -pub use eventloop::{ConnectionError, Event, EventLoop}; -pub use crate::mqttbytes::v4::*; -pub use crate::mqttbytes::*; -pub use state::{MqttState, StateError}; -#[cfg(feature = "use-rustls")] -pub use tls::Error as TlsError; -#[cfg(feature = "use-rustls")] -pub use tokio_rustls::rustls::ClientConfig; - -pub type Incoming = Packet; - -/// Current outgoing activity on the eventloop -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum Outgoing { - /// Publish packet with packet identifier. 0 implies QoS 0 - Publish(u16), - /// Subscribe packet with packet identifier - Subscribe(u16), - /// Unsubscribe packet with packet identifier - Unsubscribe(u16), - /// PubAck packet - PubAck(u16), - /// PubRec packet - PubRec(u16), - /// PubRel packet - PubRel(u16), - /// PubComp packet - PubComp(u16), - /// Ping request packet - PingReq, - /// Ping response packet - PingResp, - /// Disconnect packet - Disconnect, - /// Await for an ack for more outgoing progress - AwaitAck(u16), -} - -/// Requests by the client to mqtt event loop. Request are -/// handled one by one. -#[derive(Clone, Debug, PartialEq)] -pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, -} - -/// Key type for TLS authentication -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum Key { - RSA(Vec), - ECC(Vec), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } -} - -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } -} - -#[derive(Clone)] -pub enum Transport { - Tcp, - #[cfg(feature = "use-rustls")] - Tls(TlsConfiguration), - #[cfg(unix)] - Unix, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - Ws, - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] - Wss(TlsConfiguration), -} - -impl Default for Transport { - fn default() -> Self { - Self::tcp() - } -} - -impl Transport { - /// Use regular tcp as transport (default) - pub fn tcp() -> Self { - Self::Tcp - } - - /// Use secure tcp with tls as transport - #[cfg(feature = "use-rustls")] - pub fn tls( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - alpn, - client_auth, - }; - - Self::tls_with_config(config) - } - - #[cfg(feature = "use-rustls")] - pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { - Self::Tls(tls_config) - } - - #[cfg(unix)] - pub fn unix() -> Self { - Self::Unix - } - - /// Use websockets as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn ws() -> Self { - Self::Ws - } - - /// Use secure websockets with tls as transport - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] - pub fn wss( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }; - - Self::wss_with_config(config) - } - - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] - pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { - Self::Wss(tls_config) - } -} - -#[derive(Clone)] -#[cfg(feature = "use-rustls")] -pub enum TlsConfiguration { - Simple { - /// connection method - ca: Vec, - /// alpn settings - alpn: Option>>, - /// tls client_authentication - client_auth: Option<(Vec, Key)>, - }, - /// Injected rustls ClientConfig for TLS, to allow more customisation. - Rustls(Arc), -} - -#[cfg(feature = "use-rustls")] -impl From for TlsConfiguration { - fn from(config: ClientConfig) -> Self { - TlsConfiguration::Rustls(Arc::new(config)) - } -} - -// TODO: Should all the options be exposed as public? Drawback -// would be loosing the ability to panic when the user options -// are wrong (e.g empty client id) or aggressive (keep alive time) -/// Options to configure the behaviour of mqtt connection -#[derive(Clone)] -pub struct MqttOptions { - /// broker address that you want to connect to - broker_addr: String, - /// broker port - port: u16, - // What transport protocol to use - transport: Transport, - /// keep alive time to send pingreq to broker when the connection is idle - keep_alive: Duration, - /// clean (or) persistent session - clean_session: bool, - /// client identifier - client_id: String, - /// username and password - credentials: Option<(String, String)>, - /// maximum incoming packet size (verifies remaining length of the packet) - max_incoming_packet_size: usize, - /// Maximum outgoing packet size (only verifies publish payload size) - // TODO Verify this with all packets. This can be packet.write but message left in - // the state might be a footgun as user has to explicitly clean it. Probably state - // has to be moved to network - max_outgoing_packet_size: usize, - /// request (publish, subscribe) channel capacity - request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, - /// Minimum delay time between consecutive outgoing packets - /// while retransmitting pending packets - pending_throttle: Duration, - /// maximum number of outgoing inflight messages - inflight: u16, - /// Last will that will be issued on unexpected disconnect - last_will: Option, - /// Connection timeout - conn_timeout: u64, - /// If set to `true` MQTT acknowledgements are not sent automatically. - /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. - manual_acks: bool, -} - -impl MqttOptions { - /// New mqtt options - pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { - let id = id.into(); - if id.starts_with(' ') || id.is_empty() { - panic!("Invalid client id") - } - - MqttOptions { - broker_addr: host.into(), - port, - transport: Transport::tcp(), - keep_alive: Duration::from_secs(60), - clean_session: true, - client_id: id, - credentials: None, - max_incoming_packet_size: 10 * 1024, - max_outgoing_packet_size: 10 * 1024, - request_channel_capacity: 10, - max_request_batch: 0, - pending_throttle: Duration::from_micros(0), - inflight: 100, - last_will: None, - conn_timeout: 5, - manual_acks: false, - } - } - - /// Broker address - pub fn broker_address(&self) -> (String, u16) { - (self.broker_addr.clone(), self.port) - } - - pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { - self.last_will = Some(will); - self - } - - pub fn last_will(&self) -> Option { - self.last_will.clone() - } - - pub fn set_transport(&mut self, transport: Transport) -> &mut Self { - self.transport = transport; - self - } - - pub fn transport(&self) -> Transport { - self.transport.clone() - } - - /// Set number of seconds after which client should ping the broker - /// if there is no other data exchange - pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } - - self.keep_alive = duration; - self - } - - /// Keep alive time - pub fn keep_alive(&self) -> Duration { - self.keep_alive - } - - /// Client identifier - pub fn client_id(&self) -> String { - self.client_id.clone() - } - - /// Set packet size limit for outgoing an incoming packets - pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { - self.max_incoming_packet_size = incoming; - self.max_outgoing_packet_size = outgoing; - self - } - - /// Maximum packet size - pub fn max_packet_size(&self) -> usize { - self.max_incoming_packet_size - } - - /// `clean_session = true` removes all the state from queues & instructs the broker - /// to clean all the client state when client disconnects. - /// - /// When set `false`, broker will hold the client state and performs pending - /// operations on the client when reconnection with same `client_id` - /// happens. Local queue state is also held to retransmit packets after reconnection. - pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { - self.clean_session = clean_session; - self - } - - /// Clean session - pub fn clean_session(&self) -> bool { - self.clean_session - } - - /// Username and password - pub fn set_credentials, P: Into>( - &mut self, - username: U, - password: P, - ) -> &mut Self { - self.credentials = Some((username.into(), password.into())); - self - } - - /// Security options - pub fn credentials(&self) -> Option<(String, String)> { - self.credentials.clone() - } - - /// Set request channel capacity - pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { - self.request_channel_capacity = capacity; - self - } - - /// Request channel capacity - pub fn request_channel_capacity(&self) -> usize { - self.request_channel_capacity - } - - /// Enables throttling and sets outoing message rate to the specified 'rate' - pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { - self.pending_throttle = duration; - self - } - - /// Outgoing message rate - pub fn pending_throttle(&self) -> Duration { - self.pending_throttle - } - - /// Set number of concurrent in flight messages - pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - if inflight == 0 { - panic!("zero in flight is not allowed") - } - - self.inflight = inflight; - self - } - - /// Number of concurrent in flight messages - pub fn inflight(&self) -> u16 { - self.inflight - } - - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; - self - } - - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout - } - - /// set manual acknowledgements - pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { - self.manual_acks = manual_acks; - self - } - - /// get manual acknowledgements - pub fn manual_acks(&self) -> bool { - self.manual_acks - } -} - -#[cfg(feature = "url")] -#[derive(Debug, PartialEq, thiserror::Error)] -pub enum OptionError { - #[error("Unsupported URL scheme.")] - Scheme, - - #[error("Missing client ID.")] - ClientId, - - #[error("Invalid keep-alive value.")] - KeepAlive, - - #[error("Invalid clean-session value.")] - CleanSession, - - #[error("Invalid max-incoming-packet-size value.")] - MaxIncomingPacketSize, - - #[error("Invalid max-outgoing-packet-size value.")] - MaxOutgoingPacketSize, - - #[error("Invalid request-channel-capacity value.")] - RequestChannelCapacity, - - #[error("Invalid max-request-batch value.")] - MaxRequestBatch, - - #[error("Invalid pending-throttle value.")] - PendingThrottle, - - #[error("Invalid inflight value.")] - Inflight, - - #[error("Invalid conn-timeout value.")] - ConnTimeout, - - #[error("Unknown option: {0}")] - Unknown(String), -} - -#[cfg(feature = "url")] -impl std::convert::TryFrom for MqttOptions { - type Error = OptionError; - - fn try_from(url: url::Url) -> Result { - use std::collections::HashMap; - - let broker_addr = url.host_str().unwrap_or_default().to_owned(); - - let (transport, default_port) = match url.scheme() { - // Encrypted connections are supported, but require explicit TLS configuration. We fall - // back to the unencrypted transport layer, so that `set_transport` can be used to - // configure the encrypted transport layer with the provided TLS configuration. - "mqtts" | "ssl" => (Transport::Tcp, 8883), - "mqtt" | "tcp" => (Transport::Tcp, 1883), - _ => return Err(OptionError::Scheme), - }; - - let port = url.port().unwrap_or(default_port); - - let mut queries = url.query_pairs().collect::>(); - - let keep_alive = Duration::from_secs( - queries - .remove("keep_alive_secs") - .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) - .transpose()? - .unwrap_or(60), - ); - - let client_id = queries - .remove("client_id") - .ok_or(OptionError::ClientId)? - .into_owned(); - - let clean_session = queries - .remove("clean_session") - .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) - .transpose()? - .unwrap_or(true); - - let credentials = { - match url.username() { - "" => None, - username => Some(( - username.to_owned(), - url.password().unwrap_or_default().to_owned(), - )), - } - }; - - let max_incoming_packet_size = queries - .remove("max_incoming_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxIncomingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let max_outgoing_packet_size = queries - .remove("max_outgoing_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxOutgoingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let request_channel_capacity = queries - .remove("request_channel_capacity_num") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::RequestChannelCapacity) - }) - .transpose()? - .unwrap_or(10); - - let max_request_batch = queries - .remove("max_request_batch_num") - .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) - .transpose()? - .unwrap_or(0); - - let pending_throttle = Duration::from_micros( - queries - .remove("pending_throttle_usecs") - .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) - .transpose()? - .unwrap_or(0), - ); - - let inflight = queries - .remove("inflight_num") - .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) - .transpose()? - .unwrap_or(100); - - let conn_timeout = queries - .remove("conn_timeout_secs") - .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) - .transpose()? - .unwrap_or(5); - - if let Some((opt, _)) = queries.into_iter().next() { - return Err(OptionError::Unknown(opt.into_owned())); - } - - Ok(Self { - broker_addr, - port, - transport, - keep_alive, - clean_session, - client_id, - credentials, - max_incoming_packet_size, - max_outgoing_packet_size, - request_channel_capacity, - max_request_batch, - pending_throttle, - inflight, - last_will: None, - conn_timeout, - manual_acks: false, - }) - } -} - -// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't -// work. -impl Debug for MqttOptions { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("MqttOptions") - .field("broker_addr", &self.broker_addr) - .field("port", &self.port) - .field("keep_alive", &self.keep_alive) - .field("clean_session", &self.clean_session) - .field("client_id", &self.client_id) - .field("credentials", &self.credentials) - .field("max_packet_size", &self.max_incoming_packet_size) - .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) - .field("pending_throttle", &self.pending_throttle) - .field("inflight", &self.inflight) - .field("last_will", &self.last_will) - .field("conn_timeout", &self.conn_timeout) - .field("manual_acks", &self.manual_acks) - .finish() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[should_panic] - fn client_id_startswith_space() { - let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); - } - - #[test] - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - fn no_scheme() { - let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - - _mqtt_opts.set_transport(crate::v4::Transport::wss(Vec::from("Test CA"), None, None)); - - if let crate::v4::Transport::Wss(TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }) = _mqtt_opts.transport - { - assert_eq!(ca, Vec::from("Test CA")); - assert_eq!(client_auth, None); - assert_eq!(alpn, None); - } else { - panic!("Unexpected transport!"); - } - - assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); - } - - #[test] - #[cfg(feature = "url")] - fn from_url() { - use std::convert::TryInto; - use std::str::FromStr; - - fn opt(s: &str) -> Result { - url::Url::from_str(s).expect("valid url").try_into() - } - fn ok(s: &str) -> MqttOptions { - opt(s).expect("valid options") - } - fn err(s: &str) -> OptionError { - opt(s).expect_err("invalid options") - } - - let v = ok("mqtt://host:42?client_id=foo"); - assert_eq!(v.broker_address(), ("host".to_owned(), 42)); - assert_eq!(v.client_id(), "foo".to_owned()); - - let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); - assert_eq!(v.keep_alive, Duration::from_secs(5)); - - assert_eq!(err("mqtt://host:42"), OptionError::ClientId); - assert_eq!( - err("mqtt://host:42?client_id=foo&foo=bar"), - OptionError::Unknown("foo".to_owned()) - ); - assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); - assert_eq!( - err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), - OptionError::KeepAlive - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&clean_session=foo"), - OptionError::CleanSession - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), - OptionError::MaxIncomingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), - OptionError::MaxOutgoingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), - OptionError::RequestChannelCapacity - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), - OptionError::MaxRequestBatch - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), - OptionError::PendingThrottle - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&inflight_num=foo"), - OptionError::Inflight - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), - OptionError::ConnTimeout - ); - } - - #[test] - #[should_panic] - fn no_client_id() { - let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); - } -} diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index 811b6889d..f20efe87e 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -33,7 +33,6 @@ fn get_ack_req(qos: QoS, pkid: u16) -> Option { Some(ack) } - /// MQTT connection. Maintains all the necessary state pub struct Connection { pub eventloop: EventLoop, diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 74edd7a6c..90deb7c11 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,9 +1,9 @@ -use crate::v5::{ - framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, - StateError, Transport, -}; #[cfg(feature = "use-rustls")] use crate::v5::tls; +use crate::v5::{ + framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, StateError, + Transport, +}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index 92f18f9bd..dc788c0e4 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -9,7 +9,7 @@ use tokio::{task, time}; use async_channel::{bounded, Receiver, Sender}; use bytes::BytesMut; -use rumqttc::v4::{Event, Incoming, Outgoing, Packet}; +use rumqttc::{Event, Incoming, Outgoing, Packet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub struct Broker { diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 68337a452..3fc3e48c9 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -5,7 +5,7 @@ use tokio::{task, time}; mod broker; use broker::*; -use rumqttc::v4::*; +use rumqttc::*; async fn start_requests(count: u8, qos: QoS, delay: u64, requests_tx: Sender) { for i in 1..=count { From 5d69e48b5eb270a07ff8ffff0140944ea574dea3 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 1 Apr 2022 22:02:55 +0530 Subject: [PATCH 31/41] rumqttc: v5: move shared buffer to separate struct Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 208 ++++++---- rumqttc/src/v5/client/syncclient.rs | 74 +++- rumqttc/src/v5/eventloop.rs | 32 +- rumqttc/src/v5/mod.rs | 7 +- rumqttc/src/v5/outgoing_buf.rs | 34 ++ rumqttc/src/v5/state.rs | 596 ++++++++++++++------------- 6 files changed, 549 insertions(+), 402 deletions(-) create mode 100644 rumqttc/src/v5/outgoing_buf.rs diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 798439e79..fb39f8d6a 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -1,13 +1,11 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; +use std::sync::{Arc, Mutex}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; use crate::v5::{ client::get_ack_req, + outgoing_buf::OutgoingBuf, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, ClientError, EventLoop, MqttOptions, QoS, Request, }; @@ -16,10 +14,7 @@ use crate::v5::{ /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Debug)] pub struct AsyncClient { - pub(crate) outgoing_buf: Arc>>, - pub(crate) outgoing_buf_capacity: usize, - pub(crate) pkid_counter: u16, - pub(crate) max_inflight: u16, + pub(crate) outgoing_buf: Arc>, pub(crate) request_tx: Sender<()>, } @@ -27,15 +22,11 @@ impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); - let outgoing_buf = eventloop.request_buf().clone(); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); let request_tx = eventloop.handle(); - let max_inflight = eventloop.state.max_inflight; let client = AsyncClient { outgoing_buf, - outgoing_buf_capacity: cap, - pkid_counter: 0, - max_inflight, request_tx, }; @@ -57,13 +48,18 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = if qos != QoS::AtMostOnce { - self.increment_pkid() + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid } else { 0 }; - publish.pkid = pkid; - self.push_and_async_notify(Request::Publish(publish)) - .await?; + self.notify_async().await?; Ok(pkid) } @@ -82,19 +78,30 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = if qos != QoS::AtMostOnce { - self.increment_pkid() + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid } else { 0 }; - publish.pkid = pkid; - self.push_and_try_notify(Request::Publish(publish))?; + self.try_notify()?; Ok(pkid) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.push_and_async_notify(ack).await?; + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.notify_async().await?; } Ok(()) } @@ -102,7 +109,12 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.push_and_try_notify(ack)?; + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.try_notify()?; } Ok(()) } @@ -121,13 +133,18 @@ impl AsyncClient { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; let pkid = if qos != QoS::AtMostOnce { - self.increment_pkid() + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid } else { 0 }; - publish.pkid = pkid; - self.push_and_async_notify(Request::Publish(publish)) - .await?; + self.notify_async().await?; Ok(pkid) } @@ -138,10 +155,17 @@ impl AsyncClient { qos: QoS, ) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_async_notify(Request::Subscribe(subscribe)) - .await?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; Ok(pkid) } @@ -152,9 +176,17 @@ impl AsyncClient { qos: QoS, ) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_try_notify(Request::Subscribe(subscribe))?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; Ok(pkid) } @@ -164,10 +196,17 @@ impl AsyncClient { T: IntoIterator, { let mut subscribe = Subscribe::new_many(topics); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_async_notify(Request::Subscribe(subscribe)) - .await?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; Ok(pkid) } @@ -177,86 +216,97 @@ impl AsyncClient { T: IntoIterator, { let mut subscribe = Subscribe::new_many(topics); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_try_notify(Request::Subscribe(subscribe))?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&mut self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); - let pkid = self.increment_pkid(); - unsubscribe.pkid = pkid; - self.push_and_async_notify(Request::Unsubscribe(unsubscribe)) - .await?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.notify_async().await?; Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&mut self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); - let pkid = self.increment_pkid(); - unsubscribe.pkid = pkid; - self.push_and_try_notify(Request::Unsubscribe(unsubscribe))?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.try_notify()?; Ok(pkid) } /// Sends a MQTT disconnect to the eventloop + #[inline] pub async fn disconnect(&mut self) -> Result<(), ClientError> { - self.push_and_async_notify(Request::Disconnect).await + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.notify_async().await } /// Sends a MQTT disconnect to the eventloop + #[inline] pub fn try_disconnect(&mut self) -> Result<(), ClientError> { - self.push_and_try_notify(Request::Disconnect) + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.try_notify() } - async fn push_and_async_notify(&self, request: Request) -> Result<(), ClientError> { - { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.outgoing_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - } + #[inline] + async fn notify_async(&self) -> Result<(), ClientError> { if let Err(SendError(_)) = self.request_tx.send_async(()).await { return Err(ClientError::EventloopClosed); }; Ok(()) } - pub(crate) fn push_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.outgoing_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); + #[inline] + pub(crate) fn notify(&self) -> Result<(), ClientError> { if let Err(SendError(_)) = self.request_tx.send(()) { return Err(ClientError::EventloopClosed); }; Ok(()) } - fn push_and_try_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.outgoing_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); + #[inline] + fn try_notify(&self) -> Result<(), ClientError> { if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { return Err(ClientError::EventloopClosed); } Ok(()) } - - #[inline] - pub(crate) fn increment_pkid(&mut self) -> u16 { - self.pkid_counter = if self.pkid_counter == self.max_inflight { - 1 - } else { - self.pkid_counter + 1 - }; - self.pkid_counter - } } diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index ae94a1119..425ad4e03 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -42,8 +42,19 @@ impl Client { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = publish.pkid; - self.client.push_and_notify(Request::Publish(publish))?; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.client.notify()?; Ok(pkid) } @@ -64,7 +75,12 @@ impl Client { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.client.push_and_notify(ack)?; + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.client.notify()?; } Ok(()) } @@ -77,9 +93,19 @@ impl Client { /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); - let pkid = self.client.increment_pkid(); - subscribe.pkid = pkid; - self.client.push_and_notify(Request::Subscribe(subscribe))?; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + } else { + 0 + }; + self.client.notify()?; Ok(pkid) } @@ -98,9 +124,17 @@ impl Client { T: IntoIterator, { let mut subscribe = Subscribe::new_many(topics); - let pkid = self.client.increment_pkid(); - subscribe.pkid = pkid; - self.client.push_and_notify(Request::Subscribe(subscribe))?; + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.client.notify()?; Ok(pkid) } @@ -114,10 +148,17 @@ impl Client { /// Sends a MQTT Unsubscribe to the eventloop pub fn unsubscribe>(&mut self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); - let pkid = self.client.increment_pkid(); - unsubscribe.pkid = pkid; - self.client - .push_and_notify(Request::Unsubscribe(unsubscribe))?; + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.client.notify()?; Ok(pkid) } @@ -128,7 +169,12 @@ impl Client { /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - self.client.push_and_notify(Request::Disconnect) + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.client.notify() } /// Sends a MQTT disconnect to the eventloop diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 90deb7c11..7725c373f 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,8 +1,8 @@ #[cfg(feature = "use-rustls")] use crate::v5::tls; use crate::v5::{ - framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, StateError, - Transport, + framed::Network, outgoing_buf::OutgoingBuf, packet::*, Incoming, MqttOptions, MqttState, + Packet, Request, StateError, Transport, }; #[cfg(feature = "websocket")] @@ -55,8 +55,8 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, - incoming_buf: Arc>>, - incoming_buf_cache: VecDeque, + outgoing_buf: Arc>, + outgoing_buf_cache: VecDeque, /// Request stream pub incoming_rx: Receiver<()>, /// Requests handle to send requests @@ -76,17 +76,18 @@ impl EventLoop { /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { let (incoming_tx, incoming_rx) = bounded(1); - let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); let pending = pending.into_iter(); let max_inflight = options.inflight; let manual_acks = options.manual_acks; + let state = MqttState::new(max_inflight, manual_acks, cap); + let outgoing_buf = state.outgoing_buf.clone(); EventLoop { options, - state: MqttState::new(max_inflight, manual_acks, cap), - incoming_buf: request_buf, - incoming_buf_cache: VecDeque::with_capacity(cap), + state, + outgoing_buf, + outgoing_buf_cache: VecDeque::with_capacity(cap), incoming_tx, incoming_rx, pending, @@ -96,18 +97,11 @@ impl EventLoop { } /// Returns a handle to communicate with this eventloop + #[inline] pub fn handle(&self) -> Sender<()> { self.incoming_tx.clone() } - pub fn request_buf(&self) -> &Arc>> { - &self.incoming_buf - } - - pub fn sub_events_buf(&self) -> &Arc>> { - &self.state.incoming_buf - } - fn clean(&mut self) { self.network = None; self.keepalive_timeout = None; @@ -187,11 +181,11 @@ impl EventLoop { o = self.incoming_rx.recv_async(), if !inflight_full && !pending && !collision => match o { Ok(_request_notif) => { // swapping to avoid blocking the mutex - std::mem::swap(&mut self.incoming_buf_cache,&mut *self.incoming_buf.lock().unwrap()); - if self.incoming_buf_cache.is_empty() { + std::mem::swap(&mut self.outgoing_buf_cache, &mut self.outgoing_buf.lock().unwrap().buf); + if self.outgoing_buf_cache.is_empty() { continue; } - for request in self.incoming_buf_cache.drain(..) { + for request in self.outgoing_buf_cache.drain(..) { self.state.handle_outgoing_packet(request)?; } network.flush(&mut self.state.write).await?; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 615d18148..07230ab62 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -10,6 +10,7 @@ mod client; mod eventloop; mod framed; mod notifier; +mod outgoing_buf; #[allow(clippy::all)] mod packet; mod state; @@ -592,17 +593,13 @@ impl Debug for MqttOptions { pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, Notifier), ()> { let mut eventloop = EventLoop::new(options, cap); - let outgoing_buf = eventloop.request_buf().clone(); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); let incoming_buf = eventloop.state.incoming_buf.clone(); let incoming_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); - let max_inflight = eventloop.state.max_inflight; let client = AsyncClient { outgoing_buf, - outgoing_buf_capacity: cap, - pkid_counter: 0, - max_inflight, request_tx, }; diff --git a/rumqttc/src/v5/outgoing_buf.rs b/rumqttc/src/v5/outgoing_buf.rs new file mode 100644 index 000000000..d036dd0c8 --- /dev/null +++ b/rumqttc/src/v5/outgoing_buf.rs @@ -0,0 +1,34 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use crate::v5::Request; + +#[derive(Debug)] +pub struct OutgoingBuf { + pub(crate) buf: VecDeque, + pub(crate) pkid_counter: u16, + pub(crate) capacity: usize, +} + +impl OutgoingBuf { + #[inline] + pub fn new(cap: usize) -> Arc> { + Arc::new(Mutex::new(Self { + buf: VecDeque::with_capacity(cap), + pkid_counter: 0, + capacity: cap, + })) + } + + #[inline] + pub fn increment_pkid(&mut self) -> u16 { + self.pkid_counter = if self.pkid_counter == self.capacity as u16 { + 1 + } else { + self.pkid_counter + 1 + }; + self.pkid_counter + } +} diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 0b304f666..b6a8ebfd9 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,13 +1,14 @@ -use super::{packet::*, Incoming, Request}; - -use bytes::BytesMut; -use std::collections::VecDeque; use std::{ + collections::VecDeque, io, mem, sync::{Arc, Mutex}, time::Instant, }; +use bytes::BytesMut; + +use crate::v5::{outgoing_buf::OutgoingBuf, packet::*, Incoming, Request}; + /// Errors during state handling #[derive(Debug, thiserror::Error)] pub enum StateError { @@ -61,8 +62,6 @@ pub struct MqttState { last_outgoing: Instant, /// Number of outgoing inflight publishes pub(crate) inflight: u16, - /// Maximum number of allowed inflight - pub(crate) max_inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet pub(crate) outgoing_pub: Vec>, /// Packet ids of released QoS 2 publishes @@ -76,6 +75,7 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, pub(crate) incoming_buf: Arc>>, + pub(crate) outgoing_buf: Arc>, } impl MqttState { @@ -89,7 +89,6 @@ impl MqttState { last_incoming: Instant::now(), last_outgoing: Instant::now(), inflight: 0, - max_inflight, // index 0 is wasted as 0 is not a valid packet id outgoing_pub: vec![None; max_inflight as usize + 1], outgoing_rel: vec![None; max_inflight as usize + 1], @@ -99,6 +98,7 @@ impl MqttState { write: BytesMut::with_capacity(10 * 1024), manual_acks, incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + outgoing_buf: OutgoingBuf::new(max_inflight as usize), } } @@ -137,6 +137,11 @@ impl MqttState { self.inflight } + #[inline] + pub fn cur_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().pkid_counter + } + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { @@ -430,6 +435,11 @@ impl MqttState { Ok(pubrel) } + #[inline] + pub fn increment_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().increment_pkid() + } + ///// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation ///// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. ///// @@ -452,282 +462,298 @@ impl MqttState { #[cfg(test)] mod test { - // use super::{MqttState, StateError}; - // use crate::v5::{packet::*, Incoming, MqttOptions, Request}; - - // fn build_outgoing_publish(qos: QoS) -> Publish { - // let topic = "hello/world".to_owned(); - // let payload = vec![1, 2, 3]; - - // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - // publish.qos = qos; - // publish - // } - - // fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { - // let topic = "hello/world".to_owned(); - // let payload = vec![1, 2, 3]; - - // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - // publish.pkid = pkid; - // publish.qos = qos; - // publish - // } - - // fn build_mqttstate() -> MqttState { - // MqttState::new(100, false, 100) - // } - - // #[test] - // fn next_pkid_increments_as_expected() { - // let mqtt = build_mqttstate(); - - // for i in 1..=100 { - // let pkid = mqtt.increment_pkid(); - - // // loops between 0-99. % 100 == 0 implies border - // let expected = i % 100; - // if expected == 0 { - // break; - // } - - // assert_eq!(expected, pkid); - // } - // } - - // #[test] - // fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { - // let mut mqtt = build_mqttstate(); - - // // QoS0 Publish - // let publish = build_outgoing_publish(QoS::AtMostOnce); - - // // QoS 0 publish shouldn't be saved in queue - // mqtt.outgoing_publish(publish).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 0); - // assert_eq!(mqtt.inflight, 0); - - // // QoS1 Publish - // let publish = build_outgoing_publish(QoS::AtLeastOnce); - - // // Packet id should be set and publish should be saved in queue - // mqtt.outgoing_publish(publish.clone()).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 1); - // assert_eq!(mqtt.inflight, 1); - - // // Packet id should be incremented and publish should be saved in queue - // mqtt.outgoing_publish(publish).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 2); - // assert_eq!(mqtt.inflight, 2); - - // // QoS1 Publish - // let publish = build_outgoing_publish(QoS::ExactlyOnce); - - // // Packet id should be set and publish should be saved in queue - // mqtt.outgoing_publish(publish.clone()).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 3); - // assert_eq!(mqtt.inflight, 3); - - // // Packet id should be incremented and publish should be saved in queue - // mqtt.outgoing_publish(publish).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 4); - // assert_eq!(mqtt.inflight, 4); - // } - - // #[test] - // fn incoming_publish_should_be_added_to_queue_correctly() { - // let mut mqtt = build_mqttstate(); - - // // QoS0, 1, 2 Publishes - // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - // mqtt.handle_incoming_publish(&publish1).unwrap(); - // mqtt.handle_incoming_publish(&publish2).unwrap(); - // mqtt.handle_incoming_publish(&publish3).unwrap(); - - // let pkid = mqtt.incoming_pub[3].unwrap(); - - // // only qos2 publish should be add to queue - // assert_eq!(pkid, 3); - // } - - // #[test] - // fn incoming_publish_should_be_acked() { - // let mut mqtt = build_mqttstate(); - - // // QoS0, 1, 2 Publishes - // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - // mqtt.handle_incoming_publish(&publish1).unwrap(); - // mqtt.handle_incoming_publish(&publish2).unwrap(); - // mqtt.handle_incoming_publish(&publish3).unwrap(); - // } - - // #[test] - // fn incoming_publish_should_not_be_acked_with_manual_acks() { - // let mut mqtt = build_mqttstate(); - // mqtt.manual_acks = true; - - // // QoS0, 1, 2 Publishes - // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - // mqtt.handle_incoming_publish(&publish1).unwrap(); - // mqtt.handle_incoming_publish(&publish2).unwrap(); - // mqtt.handle_incoming_publish(&publish3).unwrap(); - - // let pkid = mqtt.incoming_pub[3].unwrap(); - // assert_eq!(pkid, 3); - - // assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); - // } - - // #[test] - // fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { - // let mut mqtt = build_mqttstate(); - // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - // mqtt.handle_incoming_publish(&publish).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - // _ => panic!("Invalid network request: {:?}", packet), - // } - // } - - // #[test] - // fn incoming_puback_should_remove_correct_publish_from_queue() { - // let mut mqtt = build_mqttstate(); - - // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - // mqtt.outgoing_publish(publish1).unwrap(); - // mqtt.outgoing_publish(publish2).unwrap(); - // assert_eq!(mqtt.inflight, 2); - - // mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); - // assert_eq!(mqtt.inflight, 1); - - // mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); - // assert_eq!(mqtt.inflight, 0); - - // assert!(mqtt.outgoing_pub[1].is_none()); - // assert!(mqtt.outgoing_pub[2].is_none()); - // } - - // #[test] - // fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { - // let mut mqtt = build_mqttstate(); - - // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - // let _publish_out = mqtt.outgoing_publish(publish1); - // let _publish_out = mqtt.outgoing_publish(publish2); - - // mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); - // assert_eq!(mqtt.inflight, 2); - - // // check if the remaining element's pkid is 1 - // let backup = mqtt.outgoing_pub[1].clone(); - // assert_eq!(backup.unwrap().pkid, 1); - - // // check if the qos2 element's release pkid is 2 - // assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); - // } - - // #[test] - // fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { - // let mut mqtt = build_mqttstate(); - - // let publish = build_outgoing_publish(QoS::ExactlyOnce); - // mqtt.outgoing_publish(publish).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::Publish(publish) => assert_eq!(publish.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - - // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - // } - - // #[test] - // fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { - // let mut mqtt = build_mqttstate(); - // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - // mqtt.handle_incoming_publish(&publish).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - - // mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - // } - - // #[test] - // fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { - // let mut mqtt = build_mqttstate(); - // let publish = build_outgoing_publish(QoS::ExactlyOnce); - - // mqtt.outgoing_publish(publish).unwrap(); - // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - - // mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); - // assert_eq!(mqtt.inflight, 0); - // } - - // #[test] - // fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { - // let mut mqtt = build_mqttstate(); - // let mut opts = MqttOptions::new("test", "localhost", 1883); - // opts.set_keep_alive(std::time::Duration::from_secs(10)); - // mqtt.outgoing_ping().unwrap(); - - // // network activity other than pingresp - // let publish = build_outgoing_publish(QoS::AtLeastOnce); - // mqtt.handle_outgoing_packet(Request::Publish(publish)) - // .unwrap(); - // mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) - // .unwrap(); - - // // should throw error because we didn't get pingresp for previous ping - // match mqtt.outgoing_ping() { - // Ok(_) => panic!("Should throw pingresp await error"), - // Err(StateError::AwaitPingResp) => (), - // Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), - // } - // } - - // #[test] - // fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { - // let mut mqtt = build_mqttstate(); - - // let mut opts = MqttOptions::new("test", "localhost", 1883); - // opts.set_keep_alive(std::time::Duration::from_secs(10)); - - // // should ping - // mqtt.outgoing_ping().unwrap(); - // mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); - - // // should ping - // mqtt.outgoing_ping().unwrap(); - // } + use super::{MqttState, StateError}; + use crate::v5::{packet::*, Incoming, MqttOptions, Request}; + + fn build_outgoing_publish(qos: QoS) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.qos = qos; + publish + } + + fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.pkid = pkid; + publish.qos = qos; + publish + } + + fn build_mqttstate() -> MqttState { + MqttState::new(100, false, 100) + } + + #[test] + fn next_pkid_increments_as_expected() { + let mqtt = build_mqttstate(); + + for i in 1..=100 { + let pkid = mqtt.increment_pkid(); + + // loops between 0-99. % 100 == 0 implies border + let expected = i % 100; + if expected == 0 { + break; + } + + assert_eq!(expected, pkid); + } + } + + #[test] + fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + let mut mqtt = build_mqttstate(); + + // QoS0 Publish + let mut publish = build_outgoing_publish(QoS::AtMostOnce); + publish.pkid = 1; + + // QoS 0 publish shouldn't be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 0); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 2; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 1); + + // Packet id should be incremented and publish should be saved in queue + publish.pkid = 3; + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 2); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 4; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 3); + + publish.pkid = 5; + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 4); + } + + #[test] + fn incoming_publish_should_be_added_to_queue_correctly() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + + // only qos2 publish should be add to queue + assert_eq!(pkid, 3); + } + + #[test] + fn incoming_publish_should_be_acked() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + } + + #[test] + fn incoming_publish_should_not_be_acked_with_manual_acks() { + let mut mqtt = build_mqttstate(); + mqtt.manual_acks = true; + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + assert_eq!(pkid, 3); + + assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); + } + + #[test] + fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + _ => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_puback_should_remove_correct_publish_from_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + mqtt.outgoing_publish(publish1).unwrap(); + mqtt.outgoing_publish(publish2).unwrap(); + assert_eq!(mqtt.inflight, 2); + + mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 0); + + assert!(mqtt.outgoing_pub[1].is_none()); + assert!(mqtt.outgoing_pub[2].is_none()); + } + + #[test] + fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + let _publish_out = mqtt.outgoing_publish(publish1); + let _publish_out = mqtt.outgoing_publish(publish2); + + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 2); + + // check if the remaining element's pkid is 1 + let backup = mqtt.outgoing_pub[1].clone(); + assert_eq!(backup.unwrap().pkid, 1); + + // check if the qos2 element's release pkid is 2 + assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + } + + #[test] + fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + mqtt.outgoing_publish(publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + let mut mqtt = build_mqttstate(); + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + + mqtt.outgoing_publish(publish).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 0); + } + + #[test] + fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + let mut mqtt = build_mqttstate(); + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + mqtt.outgoing_ping().unwrap(); + + // network activity other than pingresp + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 1; + mqtt.handle_outgoing_packet(Request::Publish(publish)) + .unwrap(); + mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + .unwrap(); + + // should throw error because we didn't get pingresp for previous ping + match mqtt.outgoing_ping() { + Ok(_) => panic!("Should throw pingresp await error"), + Err(StateError::AwaitPingResp) => (), + Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + } + } + + #[test] + fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + let mut mqtt = build_mqttstate(); + + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // should ping + mqtt.outgoing_ping().unwrap(); + mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // should ping + mqtt.outgoing_ping().unwrap(); + } } From 19cbf6f78e79703ff90c7fb49325696d64e63641 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 23 Apr 2022 19:28:52 +0530 Subject: [PATCH 32/41] fix benchmark updates --- benchmarks/clients/rumqttasync.rs | 3 ++- benchmarks/clients/rumqttasyncqos0.rs | 3 ++- benchmarks/clients/rumqttsync.rs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/clients/rumqttasync.rs b/benchmarks/clients/rumqttasync.rs index 6e4844ad8..d16cf466b 100644 --- a/benchmarks/clients/rumqttasync.rs +++ b/benchmarks/clients/rumqttasync.rs @@ -1,4 +1,5 @@ -use rumqttc::v4::*; +use rumqttc::{AsyncClient, Event, Incoming, MqttOptions, QoS}; + use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttasyncqos0.rs b/benchmarks/clients/rumqttasyncqos0.rs index 081f19ae0..d00dd300a 100644 --- a/benchmarks/clients/rumqttasyncqos0.rs +++ b/benchmarks/clients/rumqttasyncqos0.rs @@ -1,4 +1,5 @@ -use rumqttc::v4::*; +use rumqttc::{AsyncClient, Event, Incoming, MqttOptions, QoS}; + use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttsync.rs b/benchmarks/clients/rumqttsync.rs index fd94bd5df..20c3e5c20 100644 --- a/benchmarks/clients/rumqttsync.rs +++ b/benchmarks/clients/rumqttsync.rs @@ -1,4 +1,4 @@ -use rumqttc::v4::{Client, Event, Incoming, MqttOptions, QoS}; +use rumqttc::{Client, Event, Incoming, MqttOptions, QoS}; use std::error::Error; use std::thread; use std::time::{Duration, Instant}; From 662b16f9acf814dcbd5b8761fd7335fe50a0f36f Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 23 Apr 2022 19:34:58 +0530 Subject: [PATCH 33/41] Fix clippy warning --- rumqttc/src/v5/notifier.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs index 4aad90d16..3ef34cb85 100644 --- a/rumqttc/src/v5/notifier.rs +++ b/rumqttc/src/v5/notifier.rs @@ -25,7 +25,16 @@ impl Notifier { } #[inline] - pub fn next(&mut self) -> Option { + pub fn iter(&mut self) -> NotifierIter<'_> { + NotifierIter(self) + } +} + +impl Iterator for Notifier { + type Item = Incoming; + + #[inline] + fn next(&mut self) -> Option { match self.incoming_buf_cache.pop_front() { None => { mem::swap( @@ -37,11 +46,6 @@ impl Notifier { val => val, } } - - #[inline] - pub fn iter(&mut self) -> NotifierIter<'_> { - NotifierIter(self) - } } pub struct NotifierIter<'a>(&'a mut Notifier); From e9762a358d4a10e22eb1b28e0ca2df9d46463757 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 11 May 2022 22:43:41 +0530 Subject: [PATCH 34/41] Fixes to more issues pointed out by clippy --- rumqttc/src/lib.rs | 7 ++++--- rumqttc/src/v5/client/asyncclient.rs | 20 ++++++++++++-------- rumqttc/src/v5/mod.rs | 7 ++++--- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index a124168d3..b8a81a020 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -401,9 +401,10 @@ impl MqttOptions { /// Set number of seconds after which client should ping the broker /// if there is no other data exchange pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } + assert!( + !(duration.as_secs() < 5), + "Keep alives should be >= 5 secs" + ); self.keep_alive = duration; self diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index fb39f8d6a..39a436217 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -96,11 +96,13 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.buf.len() == request_buf.capacity { - return Err(ClientError::RequestsFull); + { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); } - request_buf.buf.push_back(ack); self.notify_async().await?; } Ok(()) @@ -267,11 +269,13 @@ impl AsyncClient { /// Sends a MQTT disconnect to the eventloop #[inline] pub async fn disconnect(&mut self) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.buf.len() == request_buf.capacity { - return Err(ClientError::RequestsFull); + { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); } - request_buf.buf.push_back(Request::Disconnect); self.notify_async().await } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 07230ab62..a179eb8ea 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -278,9 +278,10 @@ impl MqttOptions { /// Set number of seconds after which client should ping the broker /// if there is no other data exchange pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } + assert!( + !(duration.as_secs() < 5), + "Keep alives should be >= 5 secs" + ); self.keep_alive = duration; self From 8989f083dcfc32883a13fa4afc582af9ab7d5b6e Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 14:06:04 +0530 Subject: [PATCH 35/41] Further clippy fixes --- rumqttc/src/lib.rs | 9 ++------- rumqttc/src/v5/mod.rs | 5 +---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 6ec938bce..76776b683 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -435,10 +435,7 @@ impl MqttOptions { /// Set number of seconds after which client should ping the broker /// if there is no other data exchange pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - assert!( - !(duration.as_secs() < 5), - "Keep alives should be >= 5 secs" - ); + assert!(duration.as_secs() >= 5, "Keep alives should be >= 5 secs"); self.keep_alive = duration; self @@ -521,9 +518,7 @@ impl MqttOptions { /// Set number of concurrent in flight messages pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - if inflight == 0 { - panic!("zero in flight is not allowed") - } + assert_eq!(inflight, 0, "zero in flight is not allowed"); self.inflight = inflight; self diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index a179eb8ea..bddc7013b 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -278,10 +278,7 @@ impl MqttOptions { /// Set number of seconds after which client should ping the broker /// if there is no other data exchange pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - assert!( - !(duration.as_secs() < 5), - "Keep alives should be >= 5 secs" - ); + assert!(duration.as_secs() >= 5, "Keep alives should be >= 5 secs"); self.keep_alive = duration; self From 500105d5549a94394550febf1613a455b8f0c337 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 15:26:30 +0530 Subject: [PATCH 36/41] Fix wrong check --- rumqttc/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 76776b683..d1cf4af68 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -518,7 +518,7 @@ impl MqttOptions { /// Set number of concurrent in flight messages pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - assert_eq!(inflight, 0, "zero in flight is not allowed"); + assert!(inflight != 0, "zero in flight is not allowed"); self.inflight = inflight; self From 9b068c0bab7b0e3649790f9db12a72c8beefc6a7 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 15:48:16 +0530 Subject: [PATCH 37/41] Fix missing `Clone` --- rumqttc/src/v5/client/asyncclient.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 39a436217..f7d3f2daa 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -12,7 +12,7 @@ use crate::v5::{ /// `AsyncClient` to communicate with MQTT `Eventloop` /// This is cloneable and can be used to asynchronously Publish, Subscribe. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct AsyncClient { pub(crate) outgoing_buf: Arc>, pub(crate) request_tx: Sender<()>, From fc26ecf44549c584af28e42e894144df553bc387 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 17:03:33 +0530 Subject: [PATCH 38/41] Remove seemingly unnecessary `mut` --- rumqttc/src/v5/client/asyncclient.rs | 34 +++++++++++----------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index f7d3f2daa..6c6e498bd 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -35,7 +35,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub async fn publish( - &mut self, + &self, topic: S, qos: QoS, retain: bool, @@ -65,7 +65,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub fn try_publish( - &mut self, + &self, topic: S, qos: QoS, retain: bool, @@ -94,7 +94,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { { let mut request_buf = self.outgoing_buf.lock().unwrap(); @@ -109,7 +109,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { @@ -123,7 +123,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub async fn publish_bytes( - &mut self, + &self, topic: S, qos: QoS, retain: bool, @@ -151,11 +151,7 @@ impl AsyncClient { } /// Sends a MQTT Subscribe to the eventloop - pub async fn subscribe>( - &mut self, - topic: S, - qos: QoS, - ) -> Result { + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); let pkid = { let mut request_buf = self.outgoing_buf.lock().unwrap(); @@ -172,11 +168,7 @@ impl AsyncClient { } /// Sends a MQTT Subscribe to the eventloop - pub fn try_subscribe>( - &mut self, - topic: S, - qos: QoS, - ) -> Result { + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); let pkid = { let mut request_buf = self.outgoing_buf.lock().unwrap(); @@ -193,7 +185,7 @@ impl AsyncClient { } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub async fn subscribe_many(&mut self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -213,7 +205,7 @@ impl AsyncClient { } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn try_subscribe_many(&mut self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -233,7 +225,7 @@ impl AsyncClient { } /// Sends a MQTT Unsubscribe to the eventloop - pub async fn unsubscribe>(&mut self, topic: S) -> Result { + pub async fn unsubscribe>(&self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); let pkid = { let mut request_buf = self.outgoing_buf.lock().unwrap(); @@ -250,7 +242,7 @@ impl AsyncClient { } /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&mut self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); let pkid = { let mut request_buf = self.outgoing_buf.lock().unwrap(); @@ -268,7 +260,7 @@ impl AsyncClient { /// Sends a MQTT disconnect to the eventloop #[inline] - pub async fn disconnect(&mut self) -> Result<(), ClientError> { + pub async fn disconnect(&self) -> Result<(), ClientError> { { let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { @@ -281,7 +273,7 @@ impl AsyncClient { /// Sends a MQTT disconnect to the eventloop #[inline] - pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result<(), ClientError> { let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { return Err(ClientError::RequestsFull); From a47707eb48fc5e410421c01120cde6c3167dad10 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 17:06:30 +0530 Subject: [PATCH 39/41] Remove `mut` in syncclient and examples --- rumqttc/examples/asyncpubsub_v5.rs | 2 +- rumqttc/examples/syncpubsub_v5.rs | 2 +- rumqttc/src/v5/client/syncclient.rs | 26 +++++++++++--------------- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs index b398a5e3f..b0b7a6698 100644 --- a/rumqttc/examples/asyncpubsub_v5.rs +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), Box> { } } -async fn requests(mut client: AsyncClient) { +async fn requests(client: AsyncClient) { client .subscribe("hello/world", QoS::AtMostOnce) .await diff --git a/rumqttc/examples/syncpubsub_v5.rs b/rumqttc/examples/syncpubsub_v5.rs index 857ab4789..12957aaa9 100644 --- a/rumqttc/examples/syncpubsub_v5.rs +++ b/rumqttc/examples/syncpubsub_v5.rs @@ -21,7 +21,7 @@ fn main() { println!("Done with the stream!!"); } -fn publish(mut client: Client) { +fn publish(client: Client) { client.subscribe("hello/+/world", QoS::AtMostOnce).unwrap(); for i in 0..10 { let payload = vec![1; i as usize]; diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index 425ad4e03..420b452aa 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -30,7 +30,7 @@ impl Client { /// Sends a MQTT Publish to the eventloop pub fn publish( - &mut self, + &self, topic: S, qos: QoS, retain: bool, @@ -59,7 +59,7 @@ impl Client { } pub fn try_publish( - &mut self, + &self, topic: S, qos: QoS, retain: bool, @@ -86,12 +86,12 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the eventloop - pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result { + pub fn subscribe>(&self, topic: S, qos: QoS) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); let pkid = if qos != QoS::AtMostOnce { let mut request_buf = self.client.outgoing_buf.lock().unwrap(); @@ -110,16 +110,12 @@ impl Client { } /// Sends a MQTT Subscribe to the eventloop - pub fn try_subscribe>( - &mut self, - topic: S, - qos: QoS, - ) -> Result { + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn subscribe_many(&mut self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -138,7 +134,7 @@ impl Client { Ok(pkid) } - pub fn try_subscribe_many(&mut self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -146,7 +142,7 @@ impl Client { } /// Sends a MQTT Unsubscribe to the eventloop - pub fn unsubscribe>(&mut self, topic: S) -> Result { + pub fn unsubscribe>(&self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); let pkid = { let mut request_buf = self.client.outgoing_buf.lock().unwrap(); @@ -163,12 +159,12 @@ impl Client { } /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&mut self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the eventloop - pub fn disconnect(&mut self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result<(), ClientError> { let mut request_buf = self.client.outgoing_buf.lock().unwrap(); if request_buf.buf.len() == request_buf.capacity { return Err(ClientError::RequestsFull); @@ -178,7 +174,7 @@ impl Client { } /// Sends a MQTT disconnect to the eventloop - pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result<(), ClientError> { self.client.try_disconnect() } } From adac2066c26df9318b7a84e9e10522ad3bc539a8 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 17:15:38 +0530 Subject: [PATCH 40/41] clippy suggestion benchmarks/v5 --- benchmarks/simplerouter/src/protocol/v5.rs | 140 ++++++++++----------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/benchmarks/simplerouter/src/protocol/v5.rs b/benchmarks/simplerouter/src/protocol/v5.rs index c63fe816e..ab58f8ebd 100644 --- a/benchmarks/simplerouter/src/protocol/v5.rs +++ b/benchmarks/simplerouter/src/protocol/v5.rs @@ -135,16 +135,16 @@ pub(crate) mod connect { len } - fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + fn read(connect_flags: u8, bytes: &mut Bytes) -> Result, Error> { let last_will = match connect_flags & 0b100 { 0 if (connect_flags & 0b0011_1000) != 0 => { return Err(Error::IncorrectPacketFormat); } 0 => None, _ => { - let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = read_mqtt_bytes(bytes)?; let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); - let will_message = read_mqtt_bytes(&mut bytes)?; + let will_message = read_mqtt_bytes(bytes)?; let will_qos = qos((connect_flags & 0b11000) >> 3)?; Some(LastWill { topic: will_topic, @@ -173,11 +173,11 @@ pub(crate) mod connect { } } - fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + fn read(connect_flags: u8, bytes: &mut Bytes) -> Result, Error> { let username = match connect_flags & 0b1000_0000 { 0 => String::new(), _ => { - let username = read_mqtt_bytes(&mut bytes)?; + let username = read_mqtt_bytes(bytes)?; std::str::from_utf8(&username)?.to_owned() } }; @@ -185,7 +185,7 @@ pub(crate) mod connect { let password = match connect_flags & 0b0100_0000 { 0 => String::new(), _ => { - let password = read_mqtt_bytes(&mut bytes)?; + let password = read_mqtt_bytes(bytes)?; std::str::from_utf8(&password)?.to_owned() } }; @@ -247,7 +247,7 @@ pub(crate) mod connect { } } - fn read(mut bytes: &mut Bytes) -> Result, Error> { + fn read(bytes: &mut Bytes) -> Result, Error> { let mut session_expiry_interval = None; let mut receive_maximum = None; let mut max_packet_size = None; @@ -267,49 +267,49 @@ pub(crate) mod connect { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::SessionExpiryInterval => { - session_expiry_interval = Some(read_u32(&mut bytes)?); + session_expiry_interval = Some(read_u32(bytes)?); cursor += 4; } PropertyType::ReceiveMaximum => { - receive_maximum = Some(read_u16(&mut bytes)?); + receive_maximum = Some(read_u16(bytes)?); cursor += 2; } PropertyType::MaximumPacketSize => { - max_packet_size = Some(read_u32(&mut bytes)?); + max_packet_size = Some(read_u32(bytes)?); cursor += 4; } PropertyType::TopicAliasMaximum => { - topic_alias_max = Some(read_u16(&mut bytes)?); + topic_alias_max = Some(read_u16(bytes)?); cursor += 2; } PropertyType::RequestResponseInformation => { - request_response_info = Some(read_u8(&mut bytes)?); + request_response_info = Some(read_u8(bytes)?); cursor += 1; } PropertyType::RequestProblemInformation => { - request_problem_info = Some(read_u8(&mut bytes)?); + request_problem_info = Some(read_u8(bytes)?); cursor += 1; } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); } PropertyType::AuthenticationMethod => { - let method = read_mqtt_bytes(&mut bytes)?; + let method = read_mqtt_bytes(bytes)?; let method = std::str::from_utf8(&method)?.to_owned(); cursor += 2 + method.len(); authentication_method = Some(method); } PropertyType::AuthenticationData => { - let data = read_mqtt_bytes(&mut bytes)?; + let data = read_mqtt_bytes(bytes)?; cursor += 2 + data.len(); authentication_data = Some(data); } @@ -541,23 +541,23 @@ pub(crate) mod connack { pub fn len(&self) -> usize { let mut len = 0; - if let Some(_) = &self.session_expiry_interval { + if self.session_expiry_interval.is_some() { len += 1 + 4; } - if let Some(_) = &self.receive_max { + if self.receive_max.is_some() { len += 1 + 2; } - if let Some(_) = &self.max_qos { + if self.max_qos.is_some() { len += 1 + 1; } - if let Some(_) = &self.retain_available { + if self.retain_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.max_packet_size { + if self.max_packet_size.is_some() { len += 1 + 4; } @@ -565,7 +565,7 @@ pub(crate) mod connack { len += 1 + 2 + id.len(); } - if let Some(_) = &self.topic_alias_max { + if self.topic_alias_max.is_some() { len += 1 + 2; } @@ -577,19 +577,19 @@ pub(crate) mod connack { len += 1 + 2 + key.len() + 2 + value.len(); } - if let Some(_) = &self.wildcard_subscription_available { + if self.wildcard_subscription_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.subscription_identifiers_available { + if self.subscription_identifiers_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.shared_subscription_available { + if self.shared_subscription_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.server_keep_alive { + if self.server_keep_alive.is_some() { len += 1 + 2; } @@ -612,7 +612,7 @@ pub(crate) mod connack { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut session_expiry_interval = None; let mut receive_max = None; let mut max_qos = None; @@ -640,90 +640,90 @@ pub(crate) mod connack { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::SessionExpiryInterval => { - session_expiry_interval = Some(read_u32(&mut bytes)?); + session_expiry_interval = Some(read_u32(bytes)?); cursor += 4; } PropertyType::ReceiveMaximum => { - receive_max = Some(read_u16(&mut bytes)?); + receive_max = Some(read_u16(bytes)?); cursor += 2; } PropertyType::MaximumQos => { - max_qos = Some(read_u8(&mut bytes)?); + max_qos = Some(read_u8(bytes)?); cursor += 1; } PropertyType::RetainAvailable => { - retain_available = Some(read_u8(&mut bytes)?); + retain_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::AssignedClientIdentifier => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let id = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + id.len(); assigned_client_identifier = Some(id); } PropertyType::MaximumPacketSize => { - max_packet_size = Some(read_u32(&mut bytes)?); + max_packet_size = Some(read_u32(bytes)?); cursor += 4; } PropertyType::TopicAliasMaximum => { - topic_alias_max = Some(read_u16(&mut bytes)?); + topic_alias_max = Some(read_u16(bytes)?); cursor += 2; } PropertyType::ReasonString => { - let reason = read_mqtt_bytes(&mut bytes)?; + let reason = read_mqtt_bytes(bytes)?; let reason = std::str::from_utf8(&reason)?.to_owned(); cursor += 2 + reason.len(); reason_string = Some(reason); } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); } PropertyType::WildcardSubscriptionAvailable => { - wildcard_subscription_available = Some(read_u8(&mut bytes)?); + wildcard_subscription_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::SubscriptionIdentifierAvailable => { - subscription_identifiers_available = Some(read_u8(&mut bytes)?); + subscription_identifiers_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::SharedSubscriptionAvailable => { - shared_subscription_available = Some(read_u8(&mut bytes)?); + shared_subscription_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::ServerKeepAlive => { - server_keep_alive = Some(read_u16(&mut bytes)?); + server_keep_alive = Some(read_u16(bytes)?); cursor += 2; } PropertyType::ResponseInformation => { - let info = read_mqtt_bytes(&mut bytes)?; + let info = read_mqtt_bytes(bytes)?; let info = std::str::from_utf8(&info)?.to_owned(); cursor += 2 + info.len(); response_information = Some(info); } PropertyType::ServerReference => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let reference = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + reference.len(); server_reference = Some(reference); } PropertyType::AuthenticationMethod => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let method = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + method.len(); authentication_method = Some(method); } PropertyType::AuthenticationData => { - let data = read_mqtt_bytes(&mut bytes)?; + let data = read_mqtt_bytes(bytes)?; cursor += 2 + data.len(); authentication_data = Some(data); } @@ -926,7 +926,7 @@ pub(crate) mod publish { // FIXME: Remove indexes and use get method let stream = &self.raw[self.fixed_header.fixed_header_len..]; - let topic_len = view_u16(&stream)? as usize; + let topic_len = view_u16(stream)? as usize; let stream = &stream[2..]; let topic = view_str(stream, topic_len)?; @@ -935,8 +935,7 @@ pub(crate) mod publish { 0 => 0, 1 => { let stream = &stream[topic_len..]; - let pkid = view_u16(stream)?; - pkid + view_u16(stream)? } v => return Err(Error::InvalidQoS(v)), }; @@ -951,7 +950,7 @@ pub(crate) mod publish { pub fn view_topic(&self) -> Result<&str, Error> { // FIXME: Remove indexes let stream = &self.raw[self.fixed_header.fixed_header_len..]; - let topic_len = view_u16(&stream)? as usize; + let topic_len = view_u16(stream)? as usize; let stream = &stream[2..]; let topic = view_str(stream, topic_len)?; @@ -1030,7 +1029,7 @@ pub(crate) mod publish { buffer.put_u16(pkid); } - buffer.extend_from_slice(&payload); + buffer.extend_from_slice(payload); // TODO: Returned length is wrong in other packets. Fix it Ok(1 + count + len) @@ -1164,7 +1163,7 @@ pub(crate) mod puback { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut reason_string = None; let mut user_properties = Vec::new(); @@ -1177,20 +1176,20 @@ pub(crate) mod puback { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::ReasonString => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let reason = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + reason.len(); reason_string = Some(reason); } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); @@ -1264,8 +1263,7 @@ pub(crate) mod subscribe { retain_forward_rule: RetainForwardRule::OnEverySubscribe, }; - let mut filters = Vec::new(); - filters.push(filter); + let filters = vec![filter]; Subscribe { pkid: 0, filters, @@ -1318,10 +1316,10 @@ pub(crate) mod subscribe { let requested_qos = options & 0b0000_0011; let nolocal = options >> 2 & 0b0000_0001; - let nolocal = if nolocal == 0 { false } else { true }; + let nolocal = !(nolocal == 0); let preserve_retain = options >> 3 & 0b0000_0001; - let preserve_retain = if preserve_retain == 0 { false } else { true }; + let preserve_retain = !(preserve_retain == 0); let retain_forward_rule = (options >> 4) & 0b0000_0011; let retain_forward_rule = match retain_forward_rule { @@ -1461,7 +1459,7 @@ pub(crate) mod subscribe { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut id = None; let mut user_properties = Vec::new(); @@ -1475,7 +1473,7 @@ pub(crate) mod subscribe { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { @@ -1487,9 +1485,9 @@ pub(crate) mod subscribe { id = Some(sub_id) } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); @@ -1679,7 +1677,7 @@ pub(crate) mod suback { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut reason_string = None; let mut user_properties = Vec::new(); @@ -1692,20 +1690,20 @@ pub(crate) mod suback { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::ReasonString => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let reason = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + reason.len(); reason_string = Some(reason); } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); From 32a73ce7d35075f6114fffd3d26b07063be67b69 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 12 May 2022 17:35:02 +0530 Subject: [PATCH 41/41] Fix clippy issue --- rumqttc/examples/async_manual_acks_v5.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs index fde55ad81..079049e01 100644 --- a/rumqttc/examples/async_manual_acks_v5.rs +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -69,7 +69,7 @@ async fn main() -> Result<(), Box> { // } } -async fn requests(mut client: AsyncClient) { +async fn requests(client: &AsyncClient) { for i in 1..=10 { client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i])