From 99b2889db330e39f695b844b670da500f5c78ef1 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 17:24:32 +0530 Subject: [PATCH 01/38] rumqttc: move old code to v4 module Signed-off-by: Abhik Jain --- Cargo.lock | 109 ++-- benchmarks/clients/mesh.rs | 4 +- benchmarks/clients/rumqttasync.rs | 2 +- benchmarks/clients/rumqttasyncqos0.rs | 2 +- benchmarks/clients/rumqttsync.rs | 2 +- rumqttc/examples/async_manual_acks.rs | 2 +- rumqttc/examples/asyncpubsub.rs | 2 +- rumqttc/examples/syncpubsub.rs | 2 +- rumqttc/examples/tls.rs | 2 +- rumqttc/examples/tls2.rs | 2 +- rumqttc/src/lib.rs | 708 +------------------------- rumqttc/src/{ => v4}/client.rs | 2 +- rumqttc/src/{ => v4}/eventloop.rs | 6 +- rumqttc/src/{ => v4}/framed.rs | 2 +- rumqttc/src/v4/mod.rs | 706 +++++++++++++++++++++++++ rumqttc/src/{ => v4}/state.rs | 4 +- rumqttc/src/{ => v4}/tls.rs | 2 +- rumqttc/src/v5/mod.rs | 0 rumqttc/tests/broker.rs | 2 +- rumqttc/tests/reliability.rs | 2 +- 20 files changed, 807 insertions(+), 756 deletions(-) rename rumqttc/src/{ => v4}/client.rs (99%) rename rumqttc/src/{ => v4}/eventloop.rs (98%) rename rumqttc/src/{ => v4}/framed.rs (98%) create mode 100644 rumqttc/src/v4/mod.rs rename rumqttc/src/{ => v4}/state.rs (99%) rename rumqttc/src/{ => v4}/tls.rs (98%) create mode 100644 rumqttc/src/v5/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 8dbaf5cef..789effd4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,18 +103,18 @@ dependencies = [ [[package]] name = "async-tungstenite" -version = "0.13.1" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07b30ef0ea5c20caaa54baea49514a206308989c68be7ecd86c7f956e4da6378" +checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" dependencies = [ "futures-io", "futures-util", "log", "pin-project-lite 0.2.7", + "rustls-native-certs", "tokio 1.8.2", - "tokio-rustls", - "tungstenite 0.13.0", - "webpki-roots", + "tokio-rustls 0.23.2", + "tungstenite 0.16.0", ] [[package]] @@ -2041,14 +2041,15 @@ dependencies = [ "mqttbytes", "pollster", "pretty_env_logger", - "rustls", + "rustls 0.20.2", "rustls-native-certs", + "rustls-pemfile 0.3.0", "serde", "thiserror", "tokio 1.8.2", - "tokio-rustls", + "tokio-rustls 0.23.2", "url", - "webpki", + "webpki 0.22.0", "ws_stream_tungstenite", ] @@ -2070,7 +2071,7 @@ dependencies = [ "thiserror", "tokio 1.8.2", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.22.0", "warp", ] @@ -2161,22 +2162,52 @@ dependencies = [ "base64 0.13.0", "log", "ring", - "sct", - "webpki", + "sct 0.6.1", + "webpki 0.21.4", +] + +[[package]] +name = "rustls" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d37e5e2290f3e040b594b1a9e04377c2c671f1a1cfd9bfdef82106ac1c113f84" +dependencies = [ + "log", + "ring", + "sct 0.7.0", + "webpki 0.22.0", ] [[package]] name = "rustls-native-certs" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a07b7c1885bd8ed3831c289b7870b13ef46fe0e856d288c30d9cc17d75a2092" +checksum = "5ca9ebdfa27d3fc180e42879037b5338ab1c040c06affd00d8338598e7800943" dependencies = [ "openssl-probe", - "rustls", + "rustls-pemfile 0.2.1", "schannel", "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" +dependencies = [ + "base64 0.13.0", +] + +[[package]] +name = "rustls-pemfile" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee86d63972a7c661d1536fefe8c3c8407321c3df668891286de28abcd087360" +dependencies = [ + "base64 0.13.0", +] + [[package]] name = "ryu" version = "1.0.5" @@ -2230,6 +2261,16 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.3.1" @@ -2627,9 +2668,20 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" dependencies = [ - "rustls", + "rustls 0.19.1", + "tokio 1.8.2", + "webpki 0.21.4", +] + +[[package]] +name = "tokio-rustls" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a27d5f2b839802bd8267fa19b0530f5a08b9c08cd417976be2a65d130fe1c11b" +dependencies = [ + "rustls 0.20.2", "tokio 1.8.2", - "webpki", + "webpki 0.22.0", ] [[package]] @@ -2733,25 +2785,23 @@ dependencies = [ [[package]] name = "tungstenite" -version = "0.13.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" +checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" dependencies = [ "base64 0.13.0", "byteorder", "bytes 1.0.1", "http", "httparse", - "input_buffer", "log", "rand 0.8.4", - "rustls", + "rustls 0.20.2", "sha-1", "thiserror", "url", "utf-8", - "webpki", - "webpki-roots", + "webpki 0.22.0", ] [[package]] @@ -2999,12 +3049,13 @@ dependencies = [ ] [[package]] -name = "webpki-roots" -version = "0.21.1" +name = "webpki" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" dependencies = [ - "webpki", + "ring", + "untrusted", ] [[package]] @@ -3072,9 +3123,9 @@ dependencies = [ [[package]] name = "ws_stream_tungstenite" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34c786fc3d0a792f8a6e7a69f3b85afa1cf7b2560bbd434d7d5c32a580e153c0" +checksum = "a672ec78525bf189cefa7f1b72c55f928b3edbdb967e680ca49748ab20821045" dependencies = [ "async-tungstenite", "async_io_stream", @@ -3087,5 +3138,5 @@ dependencies = [ "pharos", "rustc_version 0.4.0", "tokio 1.8.2", - "tungstenite 0.13.0", + "tungstenite 0.16.0", ] diff --git a/benchmarks/clients/mesh.rs b/benchmarks/clients/mesh.rs index 8f6d04eec..96f38de50 100644 --- a/benchmarks/clients/mesh.rs +++ b/benchmarks/clients/mesh.rs @@ -3,7 +3,7 @@ use tokio::task; use std::thread; use std::path::PathBuf; use bytes::Bytes; -use rumqttlog::router::{Data}; +use rumqttlog::router::Data; mod common; @@ -74,5 +74,3 @@ async fn read(tx: Sender<(usize, RouterInMessage)>) { println!("Id = {}, Total size = {}", id, total_size); } - - diff --git a/benchmarks/clients/rumqttasync.rs b/benchmarks/clients/rumqttasync.rs index 3bfb49ee4..6e4844ad8 100644 --- a/benchmarks/clients/rumqttasync.rs +++ b/benchmarks/clients/rumqttasync.rs @@ -1,4 +1,4 @@ -use rumqttc::*; +use rumqttc::v4::*; use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttasyncqos0.rs b/benchmarks/clients/rumqttasyncqos0.rs index a2a668b0b..081f19ae0 100644 --- a/benchmarks/clients/rumqttasyncqos0.rs +++ b/benchmarks/clients/rumqttasyncqos0.rs @@ -1,4 +1,4 @@ -use rumqttc::*; +use rumqttc::v4::*; use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttsync.rs b/benchmarks/clients/rumqttsync.rs index da85194dd..fd94bd5df 100644 --- a/benchmarks/clients/rumqttsync.rs +++ b/benchmarks/clients/rumqttsync.rs @@ -1,4 +1,4 @@ -use rumqttc::{self, Client, Event, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{Client, Event, Incoming, MqttOptions, QoS}; use std::error::Error; use std::thread; use std::time::{Duration, Instant}; diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 62f7c20c6..6790991e3 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/asyncpubsub.rs b/rumqttc/examples/asyncpubsub.rs index 4ec2cd983..fef15451d 100644 --- a/rumqttc/examples/asyncpubsub.rs +++ b/rumqttc/examples/asyncpubsub.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::{self, AsyncClient, MqttOptions, QoS}; +use rumqttc::v4::{self, AsyncClient, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index c69ffc69f..b13abc65f 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -1,4 +1,4 @@ -use rumqttc::{self, Client, LastWill, MqttOptions, QoS}; +use rumqttc::v4::{self, Client, LastWill, MqttOptions, QoS}; use std::thread; use std::time::Duration; diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index 2bb1b2272..90bb4180f 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; +use rumqttc::v4::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; use rustls::ClientConfig; use std::error::Error; diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index c54836819..1bf5cf4fb 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; +use rumqttc::v4::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; use std::error::Error; #[tokio::main] diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 69a3c721c..0a30cbefd 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -99,709 +99,5 @@ #[macro_use] extern crate log; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::time::Duration; - -mod client; -mod eventloop; -mod framed; -mod state; -mod tls; - -pub use async_channel::{SendError, Sender, TrySendError}; -pub use client::{AsyncClient, Client, ClientError, Connection}; -pub use eventloop::{ConnectionError, Event, EventLoop}; -pub use mqttbytes::v4::*; -pub use mqttbytes::*; -pub use state::{MqttState, StateError}; -pub use tokio_rustls::rustls::ClientConfig; -pub use tls::Error; - -pub type Incoming = Packet; - -/// Current outgoing activity on the eventloop -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum Outgoing { - /// Publish packet with packet identifier. 0 implies QoS 0 - Publish(u16), - /// Subscribe packet with packet identifier - Subscribe(u16), - /// Unsubscribe packet with packet identifier - Unsubscribe(u16), - /// PubAck packet - PubAck(u16), - /// PubRec packet - PubRec(u16), - /// PubRel packet - PubRel(u16), - /// PubComp packet - PubComp(u16), - /// Ping request packet - PingReq, - /// Ping response packet - PingResp, - /// Disconnect packet - Disconnect, - /// Await for an ack for more outgoing progress - AwaitAck(u16), -} - -/// Requests by the client to mqtt event loop. Request are -/// handled one by one. -#[derive(Clone, Debug, PartialEq)] -pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, -} - -/// Key type for TLS authentication -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum Key { - RSA(Vec), - ECC(Vec), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } -} - -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } -} - -#[derive(Clone)] -pub enum Transport { - Tcp, - Tls(TlsConfiguration), - #[cfg(unix)] - Unix, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - Wss(TlsConfiguration), -} - -impl Default for Transport { - fn default() -> Self { - Self::tcp() - } -} - -impl Transport { - /// Use regular tcp as transport (default) - pub fn tcp() -> Self { - Self::Tcp - } - - /// Use secure tcp with tls as transport - pub fn tls( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - alpn, - client_auth, - }; - - Self::tls_with_config(config) - } - - pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { - Self::Tls(tls_config) - } - - #[cfg(unix)] - pub fn unix() -> Self { - Self::Unix - } - - /// Use websockets as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn ws() -> Self { - Self::Ws - } - - /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn wss( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }; - - Self::wss_with_config(config) - } - - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { - Self::Wss(tls_config) - } -} - -#[derive(Clone)] -pub enum TlsConfiguration { - Simple { - /// connection method - ca: Vec, - /// alpn settings - alpn: Option>>, - /// tls client_authentication - client_auth: Option<(Vec, Key)>, - }, - /// Injected rustls ClientConfig for TLS, to allow more customisation. - Rustls(Arc), -} - -impl From for TlsConfiguration { - fn from(config: ClientConfig) -> Self { - TlsConfiguration::Rustls(Arc::new(config)) - } -} - -// TODO: Should all the options be exposed as public? Drawback -// would be loosing the ability to panic when the user options -// are wrong (e.g empty client id) or aggressive (keep alive time) -/// Options to configure the behaviour of mqtt connection -#[derive(Clone)] -pub struct MqttOptions { - /// broker address that you want to connect to - broker_addr: String, - /// broker port - port: u16, - // What transport protocol to use - transport: Transport, - /// keep alive time to send pingreq to broker when the connection is idle - keep_alive: Duration, - /// clean (or) persistent session - clean_session: bool, - /// client identifier - client_id: String, - /// username and password - credentials: Option<(String, String)>, - /// maximum incoming packet size (verifies remaining length of the packet) - max_incoming_packet_size: usize, - /// Maximum outgoing packet size (only verifies publish payload size) - // TODO Verify this with all packets. This can be packet.write but message left in - // the state might be a footgun as user has to explicitly clean it. Probably state - // has to be moved to network - max_outgoing_packet_size: usize, - /// request (publish, subscribe) channel capacity - request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, - /// Minimum delay time between consecutive outgoing packets - /// while retransmitting pending packets - pending_throttle: Duration, - /// maximum number of outgoing inflight messages - inflight: u16, - /// Last will that will be issued on unexpected disconnect - last_will: Option, - /// Connection timeout - conn_timeout: u64, - /// If set to `true` MQTT acknowledgements are not sent automatically. - /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. - manual_acks: bool, -} - -impl MqttOptions { - /// New mqtt options - pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { - let id = id.into(); - if id.starts_with(' ') || id.is_empty() { - panic!("Invalid client id") - } - - MqttOptions { - broker_addr: host.into(), - port, - transport: Transport::tcp(), - keep_alive: Duration::from_secs(60), - clean_session: true, - client_id: id, - credentials: None, - max_incoming_packet_size: 10 * 1024, - max_outgoing_packet_size: 10 * 1024, - request_channel_capacity: 10, - max_request_batch: 0, - pending_throttle: Duration::from_micros(0), - inflight: 100, - last_will: None, - conn_timeout: 5, - manual_acks: false, - } - } - - /// Broker address - pub fn broker_address(&self) -> (String, u16) { - (self.broker_addr.clone(), self.port) - } - - pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { - self.last_will = Some(will); - self - } - - pub fn last_will(&self) -> Option { - self.last_will.clone() - } - - pub fn set_transport(&mut self, transport: Transport) -> &mut Self { - self.transport = transport; - self - } - - pub fn transport(&self) -> Transport { - self.transport.clone() - } - - /// Set number of seconds after which client should ping the broker - /// if there is no other data exchange - pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } - - self.keep_alive = duration; - self - } - - /// Keep alive time - pub fn keep_alive(&self) -> Duration { - self.keep_alive - } - - /// Client identifier - pub fn client_id(&self) -> String { - self.client_id.clone() - } - - /// Set packet size limit for outgoing an incoming packets - pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { - self.max_incoming_packet_size = incoming; - self.max_outgoing_packet_size = outgoing; - self - } - - /// Maximum packet size - pub fn max_packet_size(&self) -> usize { - self.max_incoming_packet_size - } - - /// `clean_session = true` removes all the state from queues & instructs the broker - /// to clean all the client state when client disconnects. - /// - /// When set `false`, broker will hold the client state and performs pending - /// operations on the client when reconnection with same `client_id` - /// happens. Local queue state is also held to retransmit packets after reconnection. - pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { - self.clean_session = clean_session; - self - } - - /// Clean session - pub fn clean_session(&self) -> bool { - self.clean_session - } - - /// Username and password - pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { - self.credentials = Some((username.into(), password.into())); - self - } - - /// Security options - pub fn credentials(&self) -> Option<(String, String)> { - self.credentials.clone() - } - - /// Set request channel capacity - pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { - self.request_channel_capacity = capacity; - self - } - - /// Request channel capacity - pub fn request_channel_capacity(&self) -> usize { - self.request_channel_capacity - } - - /// Enables throttling and sets outoing message rate to the specified 'rate' - pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { - self.pending_throttle = duration; - self - } - - /// Outgoing message rate - pub fn pending_throttle(&self) -> Duration { - self.pending_throttle - } - - /// Set number of concurrent in flight messages - pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - if inflight == 0 { - panic!("zero in flight is not allowed") - } - - self.inflight = inflight; - self - } - - /// Number of concurrent in flight messages - pub fn inflight(&self) -> u16 { - self.inflight - } - - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; - self - } - - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout - } - - /// set manual acknowledgements - pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { - self.manual_acks = manual_acks; - self - } - - /// get manual acknowledgements - pub fn manual_acks(&self) -> bool { - self.manual_acks - } -} - -#[cfg(feature = "url")] -#[derive(Debug, PartialEq, thiserror::Error)] -pub enum OptionError { - #[error("Unsupported URL scheme.")] - Scheme, - - #[error("Missing client ID.")] - ClientId, - - #[error("Invalid keep-alive value.")] - KeepAlive, - - #[error("Invalid clean-session value.")] - CleanSession, - - #[error("Invalid max-incoming-packet-size value.")] - MaxIncomingPacketSize, - - #[error("Invalid max-outgoing-packet-size value.")] - MaxOutgoingPacketSize, - - #[error("Invalid request-channel-capacity value.")] - RequestChannelCapacity, - - #[error("Invalid max-request-batch value.")] - MaxRequestBatch, - - #[error("Invalid pending-throttle value.")] - PendingThrottle, - - #[error("Invalid inflight value.")] - Inflight, - - #[error("Invalid conn-timeout value.")] - ConnTimeout, - - #[error("Unknown option: {0}")] - Unknown(String), -} - -#[cfg(feature = "url")] -impl std::convert::TryFrom for MqttOptions { - type Error = OptionError; - - fn try_from(url: url::Url) -> Result { - use std::collections::HashMap; - - let broker_addr = url.host_str().unwrap_or_default().to_owned(); - - let (transport, default_port) = match url.scheme() { - // Encrypted connections are supported, but require explicit TLS configuration. We fall - // back to the unencrypted transport layer, so that `set_transport` can be used to - // configure the encrypted transport layer with the provided TLS configuration. - "mqtts" | "ssl" => (Transport::Tcp, 8883), - "mqtt" | "tcp" => (Transport::Tcp, 1883), - _ => return Err(OptionError::Scheme), - }; - - let port = url.port().unwrap_or(default_port); - - let mut queries = url.query_pairs().collect::>(); - - let keep_alive = Duration::from_secs( - queries - .remove("keep_alive_secs") - .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) - .transpose()? - .unwrap_or(60), - ); - - let client_id = queries - .remove("client_id") - .ok_or(OptionError::ClientId)? - .into_owned(); - - let clean_session = queries - .remove("clean_session") - .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) - .transpose()? - .unwrap_or(true); - - let credentials = { - match url.username() { - "" => None, - username => Some(( - username.to_owned(), - url.password().unwrap_or_default().to_owned(), - )), - } - }; - - let max_incoming_packet_size = queries - .remove("max_incoming_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxIncomingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let max_outgoing_packet_size = queries - .remove("max_outgoing_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxOutgoingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let request_channel_capacity = queries - .remove("request_channel_capacity_num") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::RequestChannelCapacity) - }) - .transpose()? - .unwrap_or(10); - - let max_request_batch = queries - .remove("max_request_batch_num") - .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) - .transpose()? - .unwrap_or(0); - - let pending_throttle = Duration::from_micros( - queries - .remove("pending_throttle_usecs") - .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) - .transpose()? - .unwrap_or(0), - ); - - let inflight = queries - .remove("inflight_num") - .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) - .transpose()? - .unwrap_or(100); - - let conn_timeout = queries - .remove("conn_timeout_secs") - .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) - .transpose()? - .unwrap_or(5); - - if let Some((opt, _)) = queries.into_iter().next() { - return Err(OptionError::Unknown(opt.into_owned())); - } - - Ok(Self { - broker_addr, - port, - transport, - keep_alive, - clean_session, - client_id, - credentials, - max_incoming_packet_size, - max_outgoing_packet_size, - request_channel_capacity, - max_request_batch, - pending_throttle, - inflight, - last_will: None, - conn_timeout, - manual_acks: false - }) - } -} - -// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't -// work. -impl Debug for MqttOptions { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("MqttOptions") - .field("broker_addr", &self.broker_addr) - .field("port", &self.port) - .field("keep_alive", &self.keep_alive) - .field("clean_session", &self.clean_session) - .field("client_id", &self.client_id) - .field("credentials", &self.credentials) - .field("max_packet_size", &self.max_incoming_packet_size) - .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) - .field("pending_throttle", &self.pending_throttle) - .field("inflight", &self.inflight) - .field("last_will", &self.last_will) - .field("conn_timeout", &self.conn_timeout) - .field("manual_acks", &self.manual_acks) - .finish() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[should_panic] - fn client_id_startswith_space() { - let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); - } - - #[test] - #[cfg(feature = "websocket")] - fn no_scheme() { - let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - - _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); - - if let crate::Transport::Wss(TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }) = _mqtt_opts.transport - { - assert_eq!(ca, Vec::from("Test CA")); - assert_eq!(client_auth, None); - assert_eq!(alpn, None); - } else { - panic!("Unexpected transport!"); - } - - assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); - } - - #[test] - #[cfg(feature = "url")] - fn from_url() { - use std::convert::TryInto; - use std::str::FromStr; - - fn opt(s: &str) -> Result { - url::Url::from_str(s).expect("valid url").try_into() - } - fn ok(s: &str) -> MqttOptions { - opt(s).expect("valid options") - } - fn err(s: &str) -> OptionError { - opt(s).expect_err("invalid options") - } - - let v = ok("mqtt://host:42?client_id=foo"); - assert_eq!(v.broker_address(), ("host".to_owned(), 42)); - assert_eq!(v.client_id(), "foo".to_owned()); - - let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); - assert_eq!(v.keep_alive, Duration::from_secs(5)); - - assert_eq!(err("mqtt://host:42"), OptionError::ClientId); - assert_eq!( - err("mqtt://host:42?client_id=foo&foo=bar"), - OptionError::Unknown("foo".to_owned()) - ); - assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); - assert_eq!( - err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), - OptionError::KeepAlive - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&clean_session=foo"), - OptionError::CleanSession - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), - OptionError::MaxIncomingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), - OptionError::MaxOutgoingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), - OptionError::RequestChannelCapacity - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), - OptionError::MaxRequestBatch - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), - OptionError::PendingThrottle - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&inflight_num=foo"), - OptionError::Inflight - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), - OptionError::ConnTimeout - ); - } - - #[test] - #[should_panic] - fn no_client_id() { - let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); - } -} +pub mod v4; +pub mod v5; diff --git a/rumqttc/src/client.rs b/rumqttc/src/v4/client.rs similarity index 99% rename from rumqttc/src/client.rs rename to rumqttc/src/v4/client.rs index 139f54389..8b95c1bc2 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/v4/client.rs @@ -1,6 +1,6 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::v4::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/v4/eventloop.rs similarity index 98% rename from rumqttc/src/eventloop.rs rename to rumqttc/src/v4/eventloop.rs index b4065cf7e..d4ce78c7a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/v4/eventloop.rs @@ -1,6 +1,6 @@ -use crate::{framed::Network, Transport}; -use crate::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::{MqttOptions, Outgoing}; +use crate::v4::{framed::Network, Transport}; +use crate::v4::{tls, Incoming, MqttState, Packet, Request, StateError}; +use crate::v4::{MqttOptions, Outgoing}; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] diff --git a/rumqttc/src/framed.rs b/rumqttc/src/v4/framed.rs similarity index 98% rename from rumqttc/src/framed.rs rename to rumqttc/src/v4/framed.rs index 995234c70..96bda17da 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/v4/framed.rs @@ -3,7 +3,7 @@ use mqttbytes::v4::*; use mqttbytes::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::{Incoming, MqttState, StateError}; +use crate::v4::{Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs new file mode 100644 index 000000000..4797b0f2a --- /dev/null +++ b/rumqttc/src/v4/mod.rs @@ -0,0 +1,706 @@ +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; + +mod client; +mod eventloop; +mod framed; +mod state; +mod tls; + +pub use async_channel::{SendError, Sender, TrySendError}; +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use mqttbytes::v4::*; +pub use mqttbytes::*; +pub use state::{MqttState, StateError}; +pub use tokio_rustls::rustls::ClientConfig; +pub use tls::Error; + +pub type Incoming = Packet; + +/// Current outgoing activity on the eventloop +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// Subscribe packet with packet identifier + Subscribe(u16), + /// Unsubscribe packet with packet identifier + Unsubscribe(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, + /// Await for an ack for more outgoing progress + AwaitAck(u16), +} + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(feature = "websocket")] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/state.rs b/rumqttc/src/v4/state.rs similarity index 99% rename from rumqttc/src/state.rs rename to rumqttc/src/v4/state.rs index e33c7eb6e..c51116ccf 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/v4/state.rs @@ -1,4 +1,4 @@ -use crate::{Event, Incoming, Outgoing, Request}; +use crate::v4::{Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use mqttbytes::v4::*; @@ -487,7 +487,7 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::{Event, Incoming, MqttOptions, Outgoing, Request}; + use crate::v4::{Event, Incoming, MqttOptions, Outgoing, Request}; use mqttbytes::v4::*; use mqttbytes::*; diff --git a/rumqttc/src/tls.rs b/rumqttc/src/v4/tls.rs similarity index 98% rename from rumqttc/src/tls.rs rename to rumqttc/src/v4/tls.rs index 747614015..ea1809e23 100644 --- a/rumqttc/src/tls.rs +++ b/rumqttc/src/v4/tls.rs @@ -7,7 +7,7 @@ use tokio_rustls::rustls::{ use tokio_rustls::webpki; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::{Key, MqttOptions, TlsConfiguration}; +use crate::v4::{Key, MqttOptions, TlsConfiguration}; use std::convert::TryFrom; use std::io; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index 13c9ffc50..194f8ec92 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -9,7 +9,7 @@ use tokio::{task, time}; use async_channel::{bounded, Receiver, Sender}; use bytes::BytesMut; -use rumqttc::{Event, Incoming, Outgoing, Packet}; +use rumqttc::v4::{Event, Incoming, Outgoing, Packet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub struct Broker { diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 454c425da..3a129b6b8 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -5,7 +5,7 @@ use tokio::{task, time}; mod broker; use broker::*; -use rumqttc::*; +use rumqttc::v4::*; async fn start_requests(count: u8, qos: QoS, delay: u64, requests_tx: Sender) { for i in 1..=count { From c094fa566687905d58a40f0b2df9f1db715beb19 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 17:49:20 +0530 Subject: [PATCH 02/38] rumqttc: v5: init with v4 code Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 423 +++++++++++++++++++ rumqttc/src/v5/eventloop.rs | 380 ++++++++++++++++++ rumqttc/src/v5/framed.rs | 122 ++++++ rumqttc/src/v5/mod.rs | 707 ++++++++++++++++++++++++++++++++ rumqttc/src/v5/packet.rs | 0 rumqttc/src/v5/state.rs | 781 ++++++++++++++++++++++++++++++++++++ rumqttc/src/v5/tls.rs | 130 ++++++ 7 files changed, 2543 insertions(+) create mode 100644 rumqttc/src/v5/client.rs create mode 100644 rumqttc/src/v5/eventloop.rs create mode 100644 rumqttc/src/v5/framed.rs create mode 100644 rumqttc/src/v5/packet.rs create mode 100644 rumqttc/src/v5/state.rs create mode 100644 rumqttc/src/v5/tls.rs diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs new file mode 100644 index 000000000..8b95c1bc2 --- /dev/null +++ b/rumqttc/src/v5/client.rs @@ -0,0 +1,423 @@ +//! This module offers a high level synchronous and asynchronous abstraction to +//! async eventloop. +use crate::v4::{ConnectionError, Event, EventLoop, MqttOptions, Request}; + +use async_channel::{SendError, Sender, TrySendError}; +use bytes::Bytes; +use mqttbytes::v4::*; +use mqttbytes::*; +use std::mem; +use tokio::runtime; +use tokio::runtime::Runtime; + +/// Client Error +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("Failed to send cancel request to eventloop")] + Cancel(#[from] SendError<()>), + #[error("Failed to send mqtt requests to eventloop")] + Request(#[from] SendError), + #[error("Failed to send mqtt requests to eventloop")] + TryRequest(#[from] TrySendError), + #[error("Serialization error")] + Mqtt4(mqttbytes::Error), +} + +/// `AsyncClient` to communicate with MQTT `Eventloop` +/// This is cloneable and can be used to asynchronously Publish, Subscribe. +#[derive(Clone, Debug)] +pub struct AsyncClient { + request_tx: Sender, + cancel_tx: Sender<()>, +} + +impl AsyncClient { + /// Create a new `AsyncClient` + pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { + let mut eventloop = EventLoop::new(options, cap); + let request_tx = eventloop.handle(); + let cancel_tx = eventloop.cancel_handle(); + + let client = AsyncClient { + request_tx, + cancel_tx, + }; + + (client, eventloop) + } + + /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for + /// creating a test instance. + pub fn from_senders(request_tx: Sender, cancel_tx: Sender<()>) -> AsyncClient { + AsyncClient { + request_tx, + cancel_tx, + } + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let publish = Request::Publish(publish); + self.request_tx.send(publish).await?; + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let publish = Request::Publish(publish); + self.request_tx.try_send(publish)?; + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + let ack = get_ack_req(publish); + + if let Some(ack) = ack { + self.request_tx.send(ack).await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + let ack = get_ack_req(publish); + if let Some(ack) = ack { + self.request_tx.try_send(ack)?; + } + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: Bytes, + ) -> Result<(), ClientError> + where + S: Into, + { + let mut publish = Publish::from_bytes(topic, qos, payload); + publish.retain = retain; + let publish = Request::Publish(publish); + self.request_tx.send(publish).await?; + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + let request = Request::Subscribe(subscribe); + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + let request = Request::Subscribe(subscribe); + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + let request = Request::Subscribe(subscribe); + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + let request = Request::Subscribe(subscribe); + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + let request = Request::Unsubscribe(unsubscribe); + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + let request = Request::Unsubscribe(unsubscribe); + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub async fn disconnect(&self) -> Result<(), ClientError> { + let request = Request::Disconnect; + self.request_tx.send(request).await?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + let request = Request::Disconnect; + self.request_tx.try_send(request)?; + Ok(()) + } + + /// Stops the eventloop right away + pub async fn cancel(&self) -> Result<(), ClientError> { + self.cancel_tx.send(()).await?; + Ok(()) + } +} + +fn get_ack_req(publish: &Publish) -> Option { + let ack = match publish.qos { + QoS::AtMostOnce => return None, + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)) + }; + Some(ack) +} + +/// `Client` to communicate with MQTT eventloop `Connection`. +/// +/// Client is cloneable and can be used to synchronously Publish, Subscribe. +/// Asynchronous channel handle can also be extracted if necessary +#[derive(Clone)] +pub struct Client { + client: AsyncClient, +} + +impl Client { + /// Create a new `Client` + pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { + let (client, eventloop) = AsyncClient::new(options, cap); + let client = Client { client }; + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let connection = Connection::new(eventloop, runtime); + (client, connection) + } + + /// Sends a MQTT Publish to the eventloop + pub fn publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + pollster::block_on(self.client.publish(topic, qos, retain, payload))?; + Ok(()) + } + + pub fn try_publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result<(), ClientError> + where + S: Into, + V: Into>, + { + self.client.try_publish(topic, qos, retain, payload)?; + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + pollster::block_on(self.client.ack(publish))?; + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack( + &self, + publish: &Publish + ) -> Result<(), ClientError> + { + self.client.try_ack(publish)?; + Ok(()) + } + + + /// Sends a MQTT Subscribe to the eventloop + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { + pollster::block_on(self.client.subscribe(topic, qos))?; + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result<(), ClientError> { + self.client.try_subscribe(topic, qos)?; + Ok(()) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + pollster::block_on(self.client.subscribe_many(topics)) + } + + pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + self.client.try_subscribe_many(topics) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + pollster::block_on(self.client.unsubscribe(topic))?; + Ok(()) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + self.client.try_unsubscribe(topic)?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn disconnect(&mut self) -> Result<(), ClientError> { + pollster::block_on(self.client.disconnect())?; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + self.client.try_disconnect()?; + Ok(()) + } + + /// Stops the eventloop right away + pub fn cancel(&mut self) -> Result<(), ClientError> { + pollster::block_on(self.client.cancel())?; + Ok(()) + } +} + +/// MQTT connection. Maintains all the necessary state +pub struct Connection { + pub eventloop: EventLoop, + runtime: Option, +} + +impl Connection { + fn new(eventloop: EventLoop, runtime: Runtime) -> Connection { + Connection { + eventloop, + runtime: Some(runtime), + } + } + + /// Returns an iterator over this connection. Iterating over this is all that's + /// necessary to make connection progress and maintain a robust connection. + /// Just continuing to loop will reconnect + /// **NOTE** Don't block this while iterating + #[must_use = "Connection should be iterated over a loop to make progress"] + pub fn iter(&mut self) -> Iter { + let runtime = self.runtime.take().unwrap(); + Iter { + connection: self, + runtime, + } + } +} + +/// Iterator which polls the eventloop for connection progress +pub struct Iter<'a> { + connection: &'a mut Connection, + runtime: runtime::Runtime, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + let f = self.connection.eventloop.poll(); + match self.runtime.block_on(f) { + Ok(v) => Some(Ok(v)), + // closing of request channel should stop the iterator + Err(ConnectionError::RequestsDone) => { + trace!("Done with requests"); + None + } + Err(ConnectionError::Cancel) => { + trace!("Cancellation request received"); + None + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a> Drop for Iter<'a> { + fn drop(&mut self) { + // TODO: Don't create new runtime in drop + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + self.connection.runtime = Some(mem::replace(&mut self.runtime, runtime)); + } +} diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs new file mode 100644 index 000000000..d4ce78c7a --- /dev/null +++ b/rumqttc/src/v5/eventloop.rs @@ -0,0 +1,380 @@ +use crate::v4::{framed::Network, Transport}; +use crate::v4::{tls, Incoming, MqttState, Packet, Request, StateError}; +use crate::v4::{MqttOptions, Outgoing}; + +use async_channel::{bounded, Receiver, Sender}; +#[cfg(feature = "websocket")] +use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use mqttbytes::v4::*; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::select; +use tokio::time::{self, error::Elapsed, Instant, Sleep}; +#[cfg(feature = "websocket")] +use ws_stream_tungstenite::WsStream; + +use std::io; +#[cfg(unix)] +use std::path::Path; +use std::pin::Pin; +use std::time::Duration; +use std::vec::IntoIter; + +/// Critical errors during eventloop polling +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("Mqtt state: {0}")] + MqttState(#[from] StateError), + #[error("Timeout")] + Timeout(#[from] Elapsed), + #[error("Packet parsing error: {0}")] + Mqtt4Bytes(mqttbytes::Error), + #[error("Network: {0}")] + Network(#[from] tls::Error), + #[error("I/O: {0}")] + Io(#[from] io::Error), + #[error("Stream done")] + StreamDone, + #[error("Requests done")] + RequestsDone, + #[error("Cancel request by the user")] + Cancel, +} + +/// Eventloop with all the state of a connection +pub struct EventLoop { + /// Options of the current mqtt connection + pub options: MqttOptions, + /// Current state of the connection + pub state: MqttState, + /// Request stream + pub requests_rx: Receiver, + /// Requests handle to send requests + pub requests_tx: Sender, + /// Pending packets from last session + pub pending: IntoIter, + /// Network connection to the broker + pub(crate) network: Option, + /// Keep alive time + pub(crate) keepalive_timeout: Option>>, + /// Handle to read cancellation requests + pub(crate) cancel_rx: Receiver<()>, + /// Handle to send cancellation requests (and drops) + pub(crate) cancel_tx: Sender<()>, +} + +/// Events which can be yielded by the event loop +#[derive(Debug, PartialEq, Clone)] +pub enum Event { + Incoming(Incoming), + Outgoing(Outgoing), +} + +impl EventLoop { + /// New MQTT `EventLoop` + /// + /// When connection encounters critical errors (like auth failure), user has a choice to + /// access and update `options`, `state` and `requests`. + pub fn new(options: MqttOptions, cap: usize) -> EventLoop { + let (cancel_tx, cancel_rx) = bounded(5); + let (requests_tx, requests_rx) = bounded(cap); + let pending = Vec::new(); + let pending = pending.into_iter(); + let max_inflight = options.inflight; + let manual_acks = options.manual_acks; + + EventLoop { + options, + state: MqttState::new(max_inflight, manual_acks), + requests_tx, + requests_rx, + pending, + network: None, + keepalive_timeout: None, + cancel_rx, + cancel_tx, + } + } + + /// Returns a handle to communicate with this eventloop + pub fn handle(&self) -> Sender { + self.requests_tx.clone() + } + + /// Handle for cancelling the eventloop. + /// + /// Can be useful in cases when connection should be halted immediately + /// between half-open connection detections or (re)connection timeouts + pub(crate) fn cancel_handle(&mut self) -> Sender<()> { + self.cancel_tx.clone() + } + + fn clean(&mut self) { + self.network = None; + self.keepalive_timeout = None; + let pending = self.state.clean(); + self.pending = pending.into_iter(); + } + + /// Yields Next notification or outgoing request and periodically pings + /// the broker. Continuing to poll will reconnect to the broker if there is + /// a disconnection. + /// **NOTE** Don't block this while iterating + pub async fn poll(&mut self) -> Result { + if self.network.is_none() { + let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?; + self.network = Some(network); + + if self.keepalive_timeout.is_none() { + self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); + } + + return Ok(Event::Incoming(connack)); + } + + match self.select().await { + Ok(v) => Ok(v), + Err(e) => { + self.clean(); + Err(e) + } + } + } + + /// Select on network and requests and generate keepalive pings when necessary + async fn select(&mut self) -> Result { + let network = self.network.as_mut().unwrap(); + // let await_acks = self.state.await_acks; + let inflight_full = self.state.inflight >= self.options.inflight; + let throttle = self.options.pending_throttle; + let pending = self.pending.len() > 0; + let collision = self.state.collision.is_some(); + + // Read buffered events from previous polls before calling a new poll + if let Some(event) = self.state.events.pop_front() { + return Ok(event); + } + + // this loop is necessary since self.incoming.pop_front() might return None. In that case, + // instead of returning a None event, we try again. + select! { + // Pull a bunch of packets from network, reply in bunch and yield the first item + o = network.readb(&mut self.state) => { + o?; + // flush all the acks and return first incoming packet + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + }, + // Pull next request from user requests channel. + // If conditions in the below branch are for flow control. We read next user + // user request only when inflight messages are < configured inflight and there + // are no collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this + // to work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue + // looks like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { + Ok(request) => { + self.state.handle_outgoing_packet(request)?; + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + } + Err(_) => Err(ConnectionError::RequestsDone), + }, + // Handle the next pending packet from previous session. Disable + // this branch when done with all the pending packets + Some(request) = next_pending(throttle, &mut self.pending), if pending => { + self.state.handle_outgoing_packet(request)?; + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + + self.state.handle_outgoing_packet(Request::PingReq)?; + network.flush(&mut self.state.write).await?; + Ok(self.state.events.pop_front().unwrap()) + } + // cancellation requests to stop the polling + _ = self.cancel_rx.recv() => { + Err(ConnectionError::Cancel) + } + } + } +} + +async fn connect_or_cancel( + options: &MqttOptions, + cancel_rx: &Receiver<()>, +) -> Result<(Network, Incoming), ConnectionError> { + // select here prevents cancel request from being blocked until connection request is + // resolved. Returns with an error if connections fail continuously + select! { + o = connect(options) => o, + _ = cancel_rx.recv() => { + Err(ConnectionError::Cancel) + } + } +} + +/// This stream internally processes requests from the request stream provided to the eventloop +/// while also consuming byte stream from the network and yielding mqtt packets as the output of +/// the stream. +/// This function (for convenience) includes internal delays for users to perform internal sleeps +/// between re-connections so that cancel semantics can be used during this sleep +async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> { + // connect to the broker + let mut network = match network_connect(options).await { + Ok(network) => network, + Err(e) => { + return Err(e); + } + }; + + // make MQTT connection request (which internally awaits for ack) + let packet = match mqtt_connect(options, &mut network).await { + Ok(p) => p, + Err(e) => return Err(e), + }; + + // Last session might contain packets which aren't acked. MQTT says these packets should be + // republished in the next session + // move pending messages from state to eventloop + // let pending = self.state.clean(); + // self.pending = pending.into_iter(); + Ok((network, packet)) +} + +async fn network_connect(options: &MqttOptions) -> Result { + let network = match options.transport() { + Transport::Tcp => { + let addr = options.broker_addr.as_str(); + let port = options.port; + let socket = TcpStream::connect((addr, port)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + Transport::Tls(tls_config) => { + let socket = tls::tls_connect(&options, &tls_config).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(unix)] + Transport::Unix => { + let file = options.broker_addr.as_str(); + let socket = UnixStream::connect(Path::new(file)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Ws => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let (socket, _) = connect_async(request) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Wss(tls_config) => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let connector = tls::tls_connector(&tls_config).await?; + + let (socket, _) = connect_async_with_tls_connector(request, Some(connector)) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + }; + + Ok(network) +} + +async fn mqtt_connect( + options: &MqttOptions, + network: &mut Network, +) -> Result { + let keep_alive = options.keep_alive().as_secs() as u16; + let clean_session = options.clean_session(); + let last_will = options.last_will(); + + let mut connect = Connect::new(options.client_id()); + connect.keep_alive = keep_alive; + connect.clean_session = clean_session; + connect.last_will = last_will; + + if let Some((username, password)) = options.credentials() { + let login = Login::new(username, password); + connect.login = Some(login); + } + + // mqtt connection with timeout + time::timeout(Duration::from_secs(options.connection_timeout()), async { + network.connect(connect).await?; + Ok::<_, ConnectionError>(()) + }) + .await??; + + // wait for 'timeout' time to validate connack + let packet = time::timeout(Duration::from_secs(options.connection_timeout()), async { + let packet = match network.read().await? { + Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { + Packet::ConnAck(connack) + } + Incoming::ConnAck(connack) => { + let error = format!("Broker rejected. Reason = {:?}", connack.code); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + packet => { + let error = format!("Expecting connack. Received = {:?}", packet); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + }; + + io::Result::Ok(packet) + }) + .await??; + + Ok(packet) +} + +/// Returns the next pending packet asynchronously to be used in select! +/// This is a synchronous function but made async to make it fit in select! +pub(crate) async fn next_pending( + delay: Duration, + pending: &mut IntoIter, +) -> Option { + // return next packet with a delay + time::sleep(delay).await; + pending.next() +} diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs new file mode 100644 index 000000000..96bda17da --- /dev/null +++ b/rumqttc/src/v5/framed.rs @@ -0,0 +1,122 @@ +use bytes::BytesMut; +use mqttbytes::v4::*; +use mqttbytes::*; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::v4::{Incoming, MqttState, StateError}; +use std::io; + +/// Network transforms packets <-> frames efficiently. It takes +/// advantage of pre-allocation, buffering and vectorization when +/// appropriate to achieve performance +pub struct Network { + /// Socket for IO + socket: Box, + /// Buffered reads + read: BytesMut, + /// Maximum packet size + max_incoming_size: usize, + /// Maximum readv count + max_readb_count: usize, +} + +impl Network { + pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network { + let socket = Box::new(socket) as Box; + Network { + socket, + read: BytesMut::with_capacity(10 * 1024), + max_incoming_size, + max_readb_count: 10, + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + async fn read_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + loop { + let read = self.socket.read_buf(&mut self.read).await?; + if 0 == read { + return if self.read.is_empty() { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "connection closed by peer", + )) + } else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "connection reset by peer", + )) + }; + } + + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let required = match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => return Ok(packet), + Err(mqttbytes::Error::InsufficientBytes(required)) => required, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + // read more packets until a frame can be created. This function + // blocks until a frame can be created. Use this in a select! branch + self.read_bytes(required).await?; + } + } + + /// Read packets in bulk. This allow replies to be in bulk. This method is used + /// after the connection is established to read a bunch of incoming packets + pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { + let mut count = 0; + loop { + match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => { + state.handle_incoming_packet(packet)?; + + count += 1; + if count >= self.max_readb_count { + return Ok(()); + } + } + // If some packets are already framed, return those + Err(Error::InsufficientBytes(_)) if count > 0 => return Ok(()), + // Wait for more bytes until a frame can be created + Err(Error::InsufficientBytes(required)) => { + self.read_bytes(required).await?; + } + Err(e) => return Err(StateError::Deserialization(e)), + }; + } + } + + pub async fn connect(&mut self, connect: Connect) -> io::Result { + let mut write = BytesMut::new(); + let len = match connect.write(&mut write) { + Ok(size) => size, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + self.socket.write_all(&write[..]).await?; + Ok(len) + } + + pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { + if write.is_empty() { + return Ok(()); + } + + self.socket.write_all(&write[..]).await?; + write.clear(); + Ok(()) + } +} + +pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} +impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index e69de29bb..01800c685 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -0,0 +1,707 @@ +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; + +mod client; +mod eventloop; +mod framed; +mod packet; +mod state; +mod tls; + +pub use async_channel::{SendError, Sender, TrySendError}; +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use mqttbytes::v4::*; +pub use mqttbytes::*; +pub use state::{MqttState, StateError}; +pub use tokio_rustls::rustls::ClientConfig; +pub use tls::Error; + +pub type Incoming = Packet; + +/// Current outgoing activity on the eventloop +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// Subscribe packet with packet identifier + Subscribe(u16), + /// Unsubscribe packet with packet identifier + Unsubscribe(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, + /// Await for an ack for more outgoing progress + AwaitAck(u16), +} + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(feature = "websocket")] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/v5/packet.rs b/rumqttc/src/v5/packet.rs new file mode 100644 index 000000000..e69de29bb diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs new file mode 100644 index 000000000..c51116ccf --- /dev/null +++ b/rumqttc/src/v5/state.rs @@ -0,0 +1,781 @@ +use crate::v4::{Event, Incoming, Outgoing, Request}; + +use bytes::BytesMut; +use mqttbytes::v4::*; +use mqttbytes::*; +use std::collections::VecDeque; +use std::{io, mem, time::Instant}; + +/// Errors during state handling +#[derive(Debug, thiserror::Error)] +pub enum StateError { + /// Io Error while state is passed to network + #[error("Io error {0:?}")] + Io(#[from] io::Error), + /// Broker's error reply to client's connect packet + #[error("Connect return code `{0:?}`")] + Connect(ConnectReturnCode), + /// Invalid state for a given operation + #[error("Invalid state for a given operation")] + InvalidState, + /// Received a packet (ack) which isn't asked for + #[error("Received unsolicited ack pkid {0}")] + Unsolicited(u16), + /// Last pingreq isn't acked + #[error("Last pingreq isn't acked")] + AwaitPingResp, + /// Received a wrong packet while waiting for another packet + #[error("Received a wrong packet while waiting for another packet")] + WrongPacket, + #[error("Timeout while waiting to resolve collision")] + CollisionTimeout, + #[error("Mqtt serialization/deserialization error")] + Deserialization(mqttbytes::Error), +} + +impl From for StateError { + fn from(e: mqttbytes::Error) -> StateError { + StateError::Deserialization(e) + } +} + +/// State of the mqtt connection. +// Design: Methods will just modify the state of the object without doing any network operations +// Design: All inflight queues are maintained in a pre initialized vec with index as packet id. +// This is done for 2 reasons +// Bad acks or out of order acks aren't O(n) causing cpu spikes +// Any missing acks from the broker are detected during the next recycled use of packet ids +#[derive(Debug, Clone)] +pub struct MqttState { + /// Status of last ping + pub await_pingresp: bool, + /// Collision ping count. Collisions stop user requests + /// which inturn trigger pings. Multiple pings without + /// resolving collisions will result in error + pub collision_ping_count: usize, + /// Last incoming packet time + last_incoming: Instant, + /// Last outgoing packet time + last_outgoing: Instant, + /// Packet id of the last outgoing packet + pub(crate) last_pkid: u16, + /// Number of outgoing inflight publishes + pub(crate) inflight: u16, + /// Maximum number of allowed inflight + pub(crate) max_inflight: u16, + /// Outgoing QoS 1, 2 publishes which aren't acked yet + pub(crate) outgoing_pub: Vec>, + /// Packet ids of released QoS 2 publishes + pub(crate) outgoing_rel: Vec>, + /// Packet ids on incoming QoS 2 publishes + pub(crate) incoming_pub: Vec>, + /// Last collision due to broker not acking in order + pub collision: Option, + /// Buffered incoming packets + pub events: VecDeque, + /// Write buffer + pub write: BytesMut, + /// Indicates if acknowledgements should be send immediately + pub manual_acks: bool, +} + +impl MqttState { + /// Creates new mqtt state. Same state should be used during a + /// connection for persistent sessions while new state should + /// instantiated for clean sessions + pub fn new(max_inflight: u16, manual_acks: bool) -> Self { + MqttState { + await_pingresp: false, + collision_ping_count: 0, + last_incoming: Instant::now(), + last_outgoing: Instant::now(), + last_pkid: 0, + inflight: 0, + max_inflight, + // index 0 is wasted as 0 is not a valid packet id + outgoing_pub: vec![None; max_inflight as usize + 1], + outgoing_rel: vec![None; max_inflight as usize + 1], + incoming_pub: vec![None; std::u16::MAX as usize + 1], + collision: None, + // TODO: Optimize these sizes later + events: VecDeque::with_capacity(100), + write: BytesMut::with_capacity(10 * 1024), + manual_acks + } + } + + /// Returns inflight outgoing packets and clears internal queues + pub fn clean(&mut self) -> Vec { + let mut pending = Vec::with_capacity(100); + // remove and collect pending publishes + for publish in self.outgoing_pub.iter_mut() { + if let Some(publish) = publish.take() { + let request = Request::Publish(publish); + pending.push(request); + } + } + + // remove and collect pending releases + for rel in self.outgoing_rel.iter_mut() { + if let Some(pkid) = rel.take() { + let request = Request::PubRel(PubRel::new(pkid)); + pending.push(request); + } + } + + // remove packed ids of incoming qos2 publishes + for id in self.incoming_pub.iter_mut() { + id.take(); + } + + self.await_pingresp = false; + self.collision_ping_count = 0; + self.inflight = 0; + pending + } + + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should + /// be put on to the network by the eventloop + pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { + match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect => self.outgoing_disconnect()?, + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + _ => unimplemented!(), + }; + + self.last_outgoing = Instant::now(); + Ok(()) + } + + /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the + /// user to consume and `Packet` which for the eventloop to put on the network + /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will + /// be forwarded to user and Pubck packet will be written to network + pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { + let out = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp(), + Incoming::Publish(publish) => self.handle_incoming_publish(publish), + Incoming::SubAck(_suback) => self.handle_incoming_suback(), + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), + Incoming::PubAck(puback) => self.handle_incoming_puback(puback), + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + _ => { + error!("Invalid incoming packet = {:?}", packet); + return Err(StateError::WrongPacket); + } + }; + + out?; + self.events.push_back(Event::Incoming(packet)); + self.last_incoming = Instant::now(); + Ok(()) + } + + fn handle_incoming_suback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + /// Results in a publish notification in all the QoS cases. Replys with an ack + /// in case of QoS1 and Replys rec in case of QoS while also storing the message + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + let qos = publish.qos; + + match qos { + QoS::AtMostOnce => Ok(()), + QoS::AtLeastOnce => { + if !self.manual_acks { + let puback = PubAck::new(publish.pkid); + self.outgoing_puback(puback)? + } + Ok(()) + } + QoS::ExactlyOnce => { + let pkid = publish.pkid; + self.incoming_pub[pkid as usize] = Some(pkid); + if !self.manual_acks { + let pubrec = PubRec::new(pkid); + self.outgoing_pubrec(pubrec)?; + } + Ok(()) + } + } + } + + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + let v = match mem::replace(&mut self.outgoing_pub[puback.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited puback packet: {:?}", puback.pkid); + Err(StateError::Unsolicited(puback.pkid)) + } + }; + + if let Some(publish) = self.check_collision(puback.pkid) { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; + + publish.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + } + + v + } + + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + match mem::replace(&mut self.outgoing_pub[pubrec.pkid as usize], None) { + Some(_) => { + // NOTE: Inflight - 1 for qos2 in comp + self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); + PubRel::new(pubrec.pkid).write(&mut self.write)?; + + let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); + self.events.push_back(event); + Ok(()) + } + None => { + error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); + Err(StateError::Unsolicited(pubrec.pkid)) + } + } + } + + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + match mem::replace(&mut self.incoming_pub[pubrel.pkid as usize], None) { + Some(_) => { + PubComp::new(pubrel.pkid).write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); + self.events.push_back(event); + Ok(()) + } + None => { + error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); + Err(StateError::Unsolicited(pubrel.pkid)) + } + } + } + + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + if let Some(publish) = self.check_collision(pubcomp.pkid) { + publish.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + } + + match mem::replace(&mut self.outgoing_rel[pubcomp.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + Err(StateError::Unsolicited(pubcomp.pkid)) + } + } + } + + fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + self.await_pingresp = false; + Ok(()) + } + + /// Adds next packet identifier to QoS 1 and 2 publish packets and returns + /// it buy wrapping publish in packet + fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + if publish.qos != QoS::AtMostOnce { + if publish.pkid == 0 { + publish.pkid = self.next_pkid(); + } + + let pkid = publish.pkid; + if self + .outgoing_pub + .get(publish.pkid as usize) + .unwrap() + .is_some() + { + info!("Collision on packet id = {:?}", publish.pkid); + self.collision = Some(publish); + let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); + self.events.push_back(event); + return Ok(()); + } + + // if there is an existing publish at this pkid, this implies that broker hasn't acked this + // packet yet. This error is possible only when broker isn't acking sequentially + self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.inflight += 1; + }; + + debug!( + "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", + publish.topic, + publish.pkid, + publish.payload.len() + ); + + publish.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + let pubrel = self.save_pubrel(pubrel)?; + + debug!("Pubrel. Pkid = {}", pubrel.pkid); + PubRel::new(pubrel.pkid).write(&mut self.write)?; + + let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + puback.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + pubrec.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); + self.events.push_back(event); + Ok(()) + } + + /// check when the last control packet/pingreq packet is received and return + /// the status which tells if keep alive time has exceeded + /// NOTE: status will be checked for zero keepalive times also + fn outgoing_ping(&mut self) -> Result<(), StateError> { + let elapsed_in = self.last_incoming.elapsed(); + let elapsed_out = self.last_outgoing.elapsed(); + + if self.collision.is_some() { + self.collision_ping_count += 1; + if self.collision_ping_count >= 2 { + return Err(StateError::CollisionTimeout); + } + } + + // raise error if last ping didn't receive ack + if self.await_pingresp { + return Err(StateError::AwaitPingResp); + } + + self.await_pingresp = true; + + debug!( + "Pingreq, + last incoming packet before {} millisecs, + last outgoing request before {} millisecs", + elapsed_in.as_millis(), + elapsed_out.as_millis() + ); + + PingReq.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::PingReq); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + let pkid = self.next_pkid(); + subscription.pkid = pkid; + + debug!( + "Subscribe. Topics = {:?}, Pkid = {:?}", + subscription.filters, subscription.pkid + ); + + subscription.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + let pkid = self.next_pkid(); + unsub.pkid = pkid; + + debug!( + "Unsubscribe. Topics = {:?}, Pkid = {:?}", + unsub.topics, unsub.pkid + ); + + unsub.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); + self.events.push_back(event); + Ok(()) + } + + fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + debug!("Disconnect"); + + Disconnect.write(&mut self.write)?; + let event = Event::Outgoing(Outgoing::Disconnect); + self.events.push_back(event); + Ok(()) + } + + fn check_collision(&mut self, pkid: u16) -> Option { + if let Some(publish) = &self.collision { + if publish.pkid == pkid { + return self.collision.take(); + } + } + + None + } + + fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { + let pubrel = match pubrel.pkid { + // consider PacketIdentifier(0) as uninitialized packets + 0 => { + pubrel.pkid = self.next_pkid(); + pubrel + } + _ => pubrel, + }; + + self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid); + Ok(pubrel) + } + + /// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation + /// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. + /// + fn next_pkid(&mut self) -> u16 { + let next_pkid = self.last_pkid + 1; + + // When next packet id is at the edge of inflight queue, + // set await flag. This instructs eventloop to stop + // processing requests until all the inflight publishes + // are acked + if next_pkid == self.max_inflight { + self.last_pkid = 0; + return next_pkid; + } + + self.last_pkid = next_pkid; + next_pkid + } +} + +#[cfg(test)] +mod test { + use super::{MqttState, StateError}; + use crate::v4::{Event, Incoming, MqttOptions, Outgoing, Request}; + use mqttbytes::v4::*; + use mqttbytes::*; + + fn build_outgoing_publish(qos: QoS) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.qos = qos; + publish + } + + fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.pkid = pkid; + publish.qos = qos; + publish + } + + fn build_mqttstate() -> MqttState { + MqttState::new(100, false) + } + + #[test] + fn next_pkid_increments_as_expected() { + let mut mqtt = build_mqttstate(); + + for i in 1..=100 { + let pkid = mqtt.next_pkid(); + + // loops between 0-99. % 100 == 0 implies border + let expected = i % 100; + if expected == 0 { + break; + } + + assert_eq!(expected, pkid); + } + } + + #[test] + fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + let mut mqtt = build_mqttstate(); + + // QoS0 Publish + let publish = build_outgoing_publish(QoS::AtMostOnce); + + // QoS 0 publish shouldn't be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.last_pkid, 0); + assert_eq!(mqtt.inflight, 0); + + // QoS1 Publish + let publish = build_outgoing_publish(QoS::AtLeastOnce); + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + assert_eq!(mqtt.last_pkid, 1); + assert_eq!(mqtt.inflight, 1); + + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.last_pkid, 2); + assert_eq!(mqtt.inflight, 2); + + // QoS1 Publish + let publish = build_outgoing_publish(QoS::ExactlyOnce); + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + assert_eq!(mqtt.last_pkid, 3); + assert_eq!(mqtt.inflight, 3); + + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.last_pkid, 4); + assert_eq!(mqtt.inflight, 4); + } + + #[test] + fn incoming_publish_should_be_added_to_queue_correctly() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + + // only qos2 publish should be add to queue + assert_eq!(pkid, 3); + } + + #[test] + fn incoming_publish_should_be_acked() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { + assert_eq!(pkid, 2); + } else { + panic!("missing puback") + } + + if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] { + assert_eq!(pkid, 3); + } else { + panic!("missing PubRec") + } + } + + #[test] + fn incoming_publish_should_not_be_acked_with_manual_acks() { + let mut mqtt = build_mqttstate(); + mqtt.manual_acks = true; + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + assert_eq!(pkid, 3); + + assert!(mqtt.events.is_empty()); + } + + #[test] + fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + _ => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_puback_should_remove_correct_publish_from_queue() { + let mut mqtt = build_mqttstate(); + + let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + mqtt.outgoing_publish(publish1).unwrap(); + mqtt.outgoing_publish(publish2).unwrap(); + assert_eq!(mqtt.inflight, 2); + + mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 0); + + assert!(mqtt.outgoing_pub[1].is_none()); + assert!(mqtt.outgoing_pub[2].is_none()); + } + + #[test] + fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let mut mqtt = build_mqttstate(); + + let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + let _publish_out = mqtt.outgoing_publish(publish1); + let _publish_out = mqtt.outgoing_publish(publish2); + + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 2); + + // check if the remaining element's pkid is 1 + let backup = mqtt.outgoing_pub[1].clone(); + assert_eq!(backup.unwrap().pkid, 1); + + // check if the qos2 element's release pkid is 2 + assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + } + + #[test] + fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + + let publish = build_outgoing_publish(QoS::ExactlyOnce); + mqtt.outgoing_publish(publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + let mut mqtt = build_mqttstate(); + let publish = build_outgoing_publish(QoS::ExactlyOnce); + + mqtt.outgoing_publish(publish).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 0); + } + + #[test] + fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + let mut mqtt = build_mqttstate(); + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + mqtt.outgoing_ping().unwrap(); + + // network activity other than pingresp + let publish = build_outgoing_publish(QoS::AtLeastOnce); + mqtt.handle_outgoing_packet(Request::Publish(publish)) + .unwrap(); + mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + .unwrap(); + + // should throw error because we didn't get pingresp for previous ping + match mqtt.outgoing_ping() { + Ok(_) => panic!("Should throw pingresp await error"), + Err(StateError::AwaitPingResp) => (), + Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + } + } + + #[test] + fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + let mut mqtt = build_mqttstate(); + + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // should ping + mqtt.outgoing_ping().unwrap(); + mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // should ping + mqtt.outgoing_ping().unwrap(); + } +} diff --git a/rumqttc/src/v5/tls.rs b/rumqttc/src/v5/tls.rs new file mode 100644 index 000000000..ea1809e23 --- /dev/null +++ b/rumqttc/src/v5/tls.rs @@ -0,0 +1,130 @@ +use tokio::net::TcpStream; +use tokio_rustls::rustls; +use tokio_rustls::rustls::client::InvalidDnsNameError; +use tokio_rustls::rustls::{ + Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName, +}; +use tokio_rustls::webpki; +use tokio_rustls::{client::TlsStream, TlsConnector}; + +use crate::v4::{Key, MqttOptions, TlsConfiguration}; + +use std::convert::TryFrom; +use std::io; +use std::io::{BufReader, Cursor}; +use std::net::AddrParseError; +use std::sync::Arc; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Addr")] + Addr(#[from] AddrParseError), + #[error("I/O")] + Io(#[from] io::Error), + #[error("Web Pki")] + WebPki(#[from] webpki::Error), + #[error("DNS name")] + DNSName(#[from] InvalidDnsNameError), + #[error("TLS error")] + TLS(#[from] rustls::Error), + #[error("No valid cert in chain")] + NoValidCertInChain, +} + +// The cert handling functions return unit right now, this is a shortcut +impl From<()> for Error { + fn from(_: ()) -> Self { + Error::NoValidCertInChain + } +} + +pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result { + let config = match tls_config { + TlsConfiguration::Simple { + ca, + alpn, + client_auth, + } => { + // Add ca to root store if the connection is TLS + let mut root_cert_store = RootCertStore::empty(); + let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))?; + + let trust_anchors = certs.iter().map_while(|cert| { + if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { + Some(OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + )) + } else { + None + } + }); + + root_cert_store.add_server_trust_anchors(trust_anchors); + + if root_cert_store.is_empty() { + return Err(Error::NoValidCertInChain); + } + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + // Add der encoded client cert and key + let mut config = if let Some(client) = client_auth.as_ref() { + let certs = + rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client.0.clone())))?; + // load appropriate Key as per the user request. The underlying signature algorithm + // of key generation determines the Signature Algorithm during the TLS Handskahe. + let read_keys = match &client.1 { + Key::RSA(k) => rustls_pemfile::rsa_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + Key::ECC(k) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + }; + let keys = match read_keys { + Ok(v) => v, + Err(_e) => return Err(Error::NoValidCertInChain), + }; + + // Get the first key. Error if it's not valid + let key = match keys.first() { + Some(k) => k.clone(), + None => return Err(Error::NoValidCertInChain), + }; + + let certs = certs.into_iter().map(|cert| Certificate(cert)).collect(); + + config.with_single_cert(certs, PrivateKey(key))? + } else { + config.with_no_client_auth() + }; + + // Set ALPN + if let Some(alpn) = alpn.as_ref() { + config.alpn_protocols.extend_from_slice(alpn); + } + + Arc::new(config) + } + TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(), + }; + + Ok(TlsConnector::from(config)) +} + +pub async fn tls_connect( + options: &MqttOptions, + tls_config: &TlsConfiguration, +) -> Result, Error> { + let addr = options.broker_addr.as_str(); + let port = options.port; + let connector = tls_connector(tls_config).await?; + let domain = ServerName::try_from(addr)?; + let tcp = TcpStream::connect((addr, port)).await?; + let tls = connector.connect(domain, tcp).await?; + Ok(tls) +} From c064f0ee8a9341fe2b40bb9bbbe8d40dff134949 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:25:44 +0530 Subject: [PATCH 03/38] rumqttc: v5: packet: copy code from mqttbytes/v5 Signed-off-by: Abhik Jain --- rumqttc/Cargo.toml | 1 + rumqttc/src/v5/packet.rs | 0 rumqttc/src/v5/packet/connack.rs | 553 ++++++++++++++++ rumqttc/src/v5/packet/connect.rs | 928 +++++++++++++++++++++++++++ rumqttc/src/v5/packet/disconnect.rs | 434 +++++++++++++ rumqttc/src/v5/packet/mod.rs | 489 ++++++++++++++ rumqttc/src/v5/packet/ping.rs | 20 + rumqttc/src/v5/packet/puback.rs | 324 ++++++++++ rumqttc/src/v5/packet/pubcomp.rs | 237 +++++++ rumqttc/src/v5/packet/publish.rs | 394 ++++++++++++ rumqttc/src/v5/packet/pubrec.rs | 252 ++++++++ rumqttc/src/v5/packet/pubrel.rs | 236 +++++++ rumqttc/src/v5/packet/suback.rs | 263 ++++++++ rumqttc/src/v5/packet/subscribe.rs | 425 ++++++++++++ rumqttc/src/v5/packet/unsuback.rs | 249 +++++++ rumqttc/src/v5/packet/unsubscribe.rs | 238 +++++++ 16 files changed, 5043 insertions(+) delete mode 100644 rumqttc/src/v5/packet.rs create mode 100644 rumqttc/src/v5/packet/connack.rs create mode 100644 rumqttc/src/v5/packet/connect.rs create mode 100644 rumqttc/src/v5/packet/disconnect.rs create mode 100644 rumqttc/src/v5/packet/mod.rs create mode 100644 rumqttc/src/v5/packet/ping.rs create mode 100644 rumqttc/src/v5/packet/puback.rs create mode 100644 rumqttc/src/v5/packet/pubcomp.rs create mode 100644 rumqttc/src/v5/packet/publish.rs create mode 100644 rumqttc/src/v5/packet/pubrec.rs create mode 100644 rumqttc/src/v5/packet/pubrel.rs create mode 100644 rumqttc/src/v5/packet/suback.rs create mode 100644 rumqttc/src/v5/packet/subscribe.rs create mode 100644 rumqttc/src/v5/packet/unsuback.rs create mode 100644 rumqttc/src/v5/packet/unsubscribe.rs diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 7e8a89b89..81a7b05b7 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -43,3 +43,4 @@ tokio = { version = "1.0", features = ["full", "macros"] } matches = "0.1.8" rustls = "0.20.2" rustls-native-certs = "0.6.1" +pretty_assertions = "1.1.0" diff --git a/rumqttc/src/v5/packet.rs b/rumqttc/src/v5/packet.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/rumqttc/src/v5/packet/connack.rs b/rumqttc/src/v5/packet/connack.rs new file mode 100644 index 000000000..f0e0ebaee --- /dev/null +++ b/rumqttc/src/v5/packet/connack.rs @@ -0,0 +1,553 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum ConnectReturnCode { + Success = 0, + UnspecifiedError = 128, + MalformedPacket = 129, + ProtocolError = 130, + ImplementationSpecificError = 131, + UnsupportedProtocolVersion = 132, + ClientIdentifierNotValid = 133, + BadUserNamePassword = 134, + NotAuthorized = 135, + ServerUnavailable = 136, + ServerBusy = 137, + Banned = 138, + BadAuthenticationMethod = 140, + TopicNameInvalid = 144, + PacketTooLarge = 149, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + RetainNotSupported = 154, + QoSNotSupported = 155, + UseAnotherServer = 156, + ServerMoved = 157, + ConnectionRateExceeded = 159, +} + +/// Acknowledgement to connect packet +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, +} + +impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(self.session_present as u8); + buffer.put_u8(self.code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, +} + +impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let id = read_mqtt_string(&mut bytes)?; + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_string(&mut bytes)?; + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +/// Connection return code type +fn connect_return(num: u8) -> Result { + let code = match num { + 0 => ConnectReturnCode::Success, + 128 => ConnectReturnCode::UnspecifiedError, + 129 => ConnectReturnCode::MalformedPacket, + 130 => ConnectReturnCode::ProtocolError, + 131 => ConnectReturnCode::ImplementationSpecificError, + 132 => ConnectReturnCode::UnsupportedProtocolVersion, + 133 => ConnectReturnCode::ClientIdentifierNotValid, + 134 => ConnectReturnCode::BadUserNamePassword, + 135 => ConnectReturnCode::NotAuthorized, + 136 => ConnectReturnCode::ServerUnavailable, + 137 => ConnectReturnCode::ServerBusy, + 138 => ConnectReturnCode::Banned, + 140 => ConnectReturnCode::BadAuthenticationMethod, + 144 => ConnectReturnCode::TopicNameInvalid, + 149 => ConnectReturnCode::PacketTooLarge, + 151 => ConnectReturnCode::QuotaExceeded, + 153 => ConnectReturnCode::PayloadFormatInvalid, + 154 => ConnectReturnCode::RetainNotSupported, + 155 => ConnectReturnCode::QoSNotSupported, + 156 => ConnectReturnCode::UseAnotherServer, + 157 => ConnectReturnCode::ServerMoved, + 159 => ConnectReturnCode::ConnectionRateExceeded, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample() -> ConnAck { + let properties = ConnAckProperties { + session_expiry_interval: Some(1234), + receive_max: Some(432), + max_qos: Some(2), + retain_available: Some(1), + max_packet_size: Some(100), + assigned_client_identifier: Some("test".to_owned()), + topic_alias_max: Some(456), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + wildcard_subscription_available: Some(1), + subscription_identifiers_available: Some(1), + shared_subscription_available: Some(0), + server_keep_alive: Some(1234), + response_information: Some("test".to_owned()), + server_reference: Some("test".to_owned()), + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + ConnAck { + session_present: false, + code: ConnectReturnCode::Success, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x20, // Packet type + 0x57, // Remaining length + 0x00, 0x00, // Session, code + 0x54, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x21, 0x01, 0xb0, // Receive maximum + 0x24, 0x02, // Maximum qos + 0x25, 0x01, // Retain available + 0x27, 0x00, 0x00, 0x00, 0x64, // Maximum packet size + 0x12, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Assigned client identifier + 0x22, 0x01, 0xc8, // Topic alias max + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x28, 0x01, // wildcard_subscription_available + 0x29, 0x01, // subscription_identifiers_available + 0x2a, 0x00, // shared_subscription_available + 0x13, 0x04, 0xd2, // server keep_alive + 0x1a, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // response_information + 0x1c, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication data + ] + } + + #[test] + fn connack_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connack = ConnAck::read(fixed_header, connack_bytes).unwrap(); + + assert_eq!(connack, sample()); + } + + #[test] + fn connack_encoding_works() { + let connack = sample(); + let mut buf = BytesMut::new(); + connack.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/connect.rs b/rumqttc/src/v5/packet/connect.rs new file mode 100644 index 000000000..1c630ae79 --- /dev/null +++ b/rumqttc/src/v5/packet/connect.rs @@ -0,0 +1,928 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Connection packet initiated by the client +#[derive(Debug, Clone, PartialEq)] +pub struct Connect { + /// Mqtt protocol version + pub protocol: Protocol, + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, +} + +impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn set_login, P: Into>(&mut self, u: U, p: P) -> &mut Connect { + let login = Login { + username: u.into(), + password: p.into(), + }; + + self.login = Some(login); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_string(&mut bytes)?; + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol = match protocol_level { + 4 => Protocol::V4, + 5 => Protocol::V5, + num => return Err(Error::InvalidProtocolLevel(num)), + }; + + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + // Properties in variable header + let properties = match protocol { + Protocol::V5 => ConnectProperties::read(&mut bytes)?, + Protocol::V4 => None, + }; + + let client_id = read_mqtt_string(&mut bytes)?; + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + protocol, + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + + match self.protocol { + Protocol::V4 => buffer.put_u8(0x04), + Protocol::V5 => buffer.put_u8(0x05), + } + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } +} + +/// LastWill that broker forwards on behalf of the client +#[derive(Debug, Clone, PartialEq)] +pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + pub properties: Option, +} + +impl LastWill { + pub fn new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 0; + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + }; + + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + // Properties in variable header + let properties = WillProperties::read(&mut bytes)?; + + let will_topic = read_mqtt_string(&mut bytes)?; + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + properties, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct WillProperties { + pub delay_interval: Option, + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub content_type: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, +} + +impl WillProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.delay_interval.is_some() { + len += 1 + 4; + } + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut delay_interval = None; + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut content_type = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::WillDelayInterval => { + delay_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(WillProperties { + delay_interval, + payload_format_indicator, + message_expiry_interval, + content_type, + response_topic, + correlation_data, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(delay_interval) = self.delay_interval { + buffer.put_u8(PropertyType::WillDelayInterval as u8); + buffer.put_u32(delay_interval); + } + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Login { + pub username: String, + pub password: String, +} + +impl Login { + pub fn new, P: Into>(u: U, p: P) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, +} + +impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_maximum { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(request_response_info) = self.request_response_info { + buffer.put_u8(PropertyType::RequestResponseInformation as u8); + buffer.put_u8(request_response_info); + } + + if let Some(request_problem_info) = self.request_problem_info { + buffer.put_u8(PropertyType::RequestProblemInformation as u8); + buffer.put_u8(request_problem_info); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn sample() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will_properties = WillProperties { + delay_interval: Some(1234), + payload_format_indicator: Some(0), + message_expiry_interval: Some(4321), + content_type: Some("test".to_owned()), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: Some(will_properties), + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x9d, // remaining len + 0x01, // remaining len + 0x00, 0x04, // 4 + 0x4d, // M + 0x51, // Q + 0x54, // T + 0x54, // T + 0x05, // Level + 0xc6, // connect flags + 0x00, 0x00, // keep alive + 0x2f, // properties len + 0x11, 0x00, 0x00, 0x04, 0xd2, // session expiry interval + 0x21, 0x01, 0xb0, // receive_maximum + 0x27, 0x00, 0x00, 0x00, 0x64, // max packet size + 0x22, 0x01, 0xc8, // topic_alias_max + 0x19, 0x01, // request_response_info + 0x17, 0x01, // request_problem_info + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication_method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication_data + 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, // client id + 0x2f, // will properties len + 0x18, 0x00, 0x00, 0x04, 0xd2, // will delay interval + 0x01, 0x00, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message expiry interval + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content type + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // will user properties + 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, // will topic + 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, // will payload + 0x00, 0x06, 0x6d, 0x61, 0x74, 0x74, 0x65, 0x6f, // username + 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, // password + ] + } + + #[test] + fn connect1_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample()); + } + + #[test] + fn connect1_encoding_works() { + let connect = sample(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: "hackathonmqtt5test".to_owned(), + clean_session: true, + last_will: None, + login: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x1f, 0x00, // remaining len + 0x04, // 4 + 0x4d, 0x51, 0x54, 0x54, // MQTT + 0x05, // level + 0x02, // connect flags + 0x00, 0x0a, // keep alive + 0x00, 0x00, 0x12, 0x68, 0x61, 0x63, 0x6b, 0x61, 0x74, 0x68, 0x6f, 0x6e, 0x6d, 0x71, + 0x74, 0x74, 0x35, 0x74, 0x65, 0x73, 0x74, // payload + 0x10, 0x11, 0x12, // extra bytes in the stream + ] + } + + #[test] + fn connect2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample2()); + } + + #[test] + fn connect2_encoding_works() { + let connect = sample2(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample2_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len() - 3)]); + } + + fn sample3() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: None, + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample3_bytes() -> Vec { + vec![ + 0x10, 0x6e, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0xc6, 0x00, 0x00, 0x2f, 0x11, + 0x00, 0x00, 0x04, 0xd2, 0x21, 0x01, 0xb0, 0x27, 0x00, 0x00, 0x00, 0x64, 0x22, 0x01, + 0xc8, 0x19, 0x01, 0x17, 0x01, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, + 0x74, 0x65, 0x73, 0x74, 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x16, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x00, 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, 0x00, 0x06, 0x6d, + 0x61, 0x74, 0x74, 0x65, 0x6f, 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, + ] + } + + #[test] + fn connect3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample3()); + } + + #[test] + fn connect3_encoding_works() { + let connect = sample3(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample3_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len())]); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/disconnect.rs b/rumqttc/src/v5/packet/disconnect.rs new file mode 100644 index 000000000..508f3c427 --- /dev/null +++ b/rumqttc/src/v5/packet/disconnect.rs @@ -0,0 +1,434 @@ +use std::convert::{TryFrom, TryInto}; + +use bytes::{BufMut, BytesMut, Bytes}; + +use super::*; + +use super::{property, PropertyType}; + +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum DisconnectReasonCode { + /// Close the connection normally. Do not send the Will Message. + NormalDisconnection = 0x00, + /// The Client wishes to disconnect but requires that the Server also publishes its Will Message. + DisconnectWithWillMessage = 0x04, + /// The Connection is closed but the sender either does not wish to reveal the reason, or none of the other Reason Codes apply. + UnspecifiedError = 0x80, + /// The received packet does not conform to this specification. + MalformedPacket = 0x81, + /// An unexpected or out of order packet was received. + ProtocolError = 0x82, + /// The packet received is valid but cannot be processed by this implementation. + ImplementationSpecificError = 0x83, + /// The request is not authorized. + NotAuthorized = 0x87, + /// The Server is busy and cannot continue processing requests from this Client. + ServerBusy = 0x89, + /// The Server is shutting down. + ServerShuttingDown = 0x8B, + /// The Connection is closed because no packet has been received for 1.5 times the Keepalive time. + KeepAliveTimeout = 0x8D, + /// Another Connection using the same ClientID has connected causing this Connection to be closed. + SessionTakenOver = 0x8E, + /// The Topic Filter is correctly formed, but is not accepted by this Sever. + TopicFilterInvalid = 0x8F, + /// The Topic Name is correctly formed, but is not accepted by this Client or Server. + TopicNameInvalid = 0x90, + /// The Client or Server has received more than Receive Maximum publication for which it has not sent PUBACK or PUBCOMP. + ReceiveMaximumExceeded = 0x93, + /// The Client or Server has received a PUBLISH packet containing a Topic Alias which is greater than the Maximum Topic Alias it sent in the CONNECT or CONNACK packet. + TopicAliasInvalid = 0x94, + /// The packet size is greater than Maximum Packet Size for this Client or Server. + PacketTooLarge = 0x95, + /// The received data rate is too high. + MessageRateTooHigh = 0x96, + /// An implementation or administrative imposed limit has been exceeded. + QuotaExceeded = 0x97, + /// The Connection is closed due to an administrative action. + AdministrativeAction = 0x98, + /// The payload format does not match the one specified by the Payload Format Indicator. + PayloadFormatInvalid = 0x99, + /// The Server has does not support retained messages. + RetainNotSupported = 0x9A, + /// The Client specified a QoS greater than the QoS specified in a Maximum QoS in the CONNACK. + QoSNotSupported = 0x9B, + /// The Client should temporarily change its Server. + UseAnotherServer = 0x9C, + /// The Server is moved and the Client should permanently change its server location. + ServerMoved = 0x9D, + /// The Server does not support Shared Subscriptions. + SharedSubscriptionNotSupported = 0x9E, + /// This connection is closed because the connection rate is too high. + ConnectionRateExceeded = 0x9F, + /// The maximum connection time authorized for this connection has been exceeded. + MaximumConnectTime = 0xA0, + /// The Server does not support Subscription Identifiers; the subscription is not accepted. + SubscriptionIdentifiersNotSupported = 0xA1, + /// The Server does not support Wildcard subscription; the subscription is not accepted. + WildcardSubscriptionsNotSupported = 0xA2, +} + +impl TryFrom for DisconnectReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let rc = match value { + 0x00 => Self::NormalDisconnection, + 0x04 => Self::DisconnectWithWillMessage, + 0x80 => Self::UnspecifiedError, + 0x81 => Self::MalformedPacket, + 0x82 => Self::ProtocolError, + 0x83 => Self::ImplementationSpecificError, + 0x87 => Self::NotAuthorized, + 0x89 => Self::ServerBusy, + 0x8B => Self::ServerShuttingDown, + 0x8D => Self::KeepAliveTimeout, + 0x8E => Self::SessionTakenOver, + 0x8F => Self::TopicFilterInvalid, + 0x90 => Self::TopicNameInvalid, + 0x93 => Self::ReceiveMaximumExceeded, + 0x94 => Self::TopicAliasInvalid, + 0x95 => Self::PacketTooLarge, + 0x96 => Self::MessageRateTooHigh, + 0x97 => Self::QuotaExceeded, + 0x98 => Self::AdministrativeAction, + 0x99 => Self::PayloadFormatInvalid, + 0x9A => Self::RetainNotSupported, + 0x9B => Self::QoSNotSupported, + 0x9C => Self::UseAnotherServer, + 0x9D => Self::ServerMoved, + 0x9E => Self::SharedSubscriptionNotSupported, + 0x9F => Self::ConnectionRateExceeded, + 0xA0 => Self::MaximumConnectTime, + 0xA1 => Self::SubscriptionIdentifiersNotSupported, + 0xA2 => Self::WildcardSubscriptionsNotSupported, + other => return Err(Error::InvalidConnectReturnCode(other)), + }; + + Ok(rc) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DisconnectProperties { + /// Session Expiry Interval in seconds + pub session_expiry_interval: Option, + + /// Human readable reason for the disconnect + pub reason_string: Option, + + /// List of user properties + pub user_properties: Vec<(String, String)>, + + /// String which can be used by the Client to identify another Server to use. + pub server_reference: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Disconnect { + /// Disconnect Reason Code + pub reason_code: DisconnectReasonCode, + + /// Disconnect Properties + pub properties: Option, +} + +impl DisconnectProperties { + pub fn new() -> Self { + Self { + session_expiry_interval: None, + reason_string: None, + user_properties: Vec::new(), + server_reference: None, + } + } + + fn len(&self) -> usize { + let mut length = 0; + + if self.session_expiry_interval.is_some() { + length += 1 + 4; + } + + if let Some(reason) = &self.reason_string { + length += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + length += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(server_reference) = &self.server_reference { + length += 1 + 2 + server_reference.len(); + } + + length + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let (properties_len_len, properties_len) = length(bytes.iter())?; + + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut session_expiry_interval = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut server_reference = None; + + let mut cursor = 0; + + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + let properties = Self { + session_expiry_interval, + reason_string, + user_properties, + server_reference, + }; + + Ok(Some(properties)) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let length = self.len(); + write_remaining_length(buffer, length)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + Ok(()) + } +} + +impl Disconnect { + pub fn new() -> Self { + Self { + reason_code: DisconnectReasonCode::NormalDisconnection, + properties: None, + } + } + + fn len(&self) -> usize { + if self.reason_code == DisconnectReasonCode::NormalDisconnection + && self.properties.is_none() + { + return 2; // Packet type + 0x00 + } + + let mut length = 0; + + match &self.properties { + Some(properties) => { + length += 1; // Disconnect Reason Code + + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + length += properties_len_len + properties_len; + } + None if self.reason_code == DisconnectReasonCode::NormalDisconnection => {} + None => { + length += 1; // Disconnect Reason Code + } + }; + + length + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let packet_type = fixed_header.byte1 >> 4; + let flags = fixed_header.byte1 & 0b0000_1111; + + bytes.advance(fixed_header.fixed_header_len); + + if packet_type != PacketType::Disconnect as u8 { + return Err(Error::InvalidPacketType(packet_type)); + }; + + if flags != 0x00 { + return Err(Error::MalformedPacket); + }; + + if fixed_header.remaining_len == 0 { + return Ok(Self::new()); + } + + let reason_code = read_u8(&mut bytes)?; + + let disconnect = Self { + reason_code: reason_code.try_into()?, + properties: DisconnectProperties::extract(&mut bytes)?, + }; + + Ok(disconnect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xE0); + + let length = self.len(); + + if length == 2 { + buffer.put_u8(0x00); + + return Ok(length); + } + + let len_len = write_remaining_length(buffer, length)?; + + buffer.put_u8(self.reason_code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + len_len + length) + } +} + +#[cfg(test)] +mod test { + use bytes::BytesMut; + + use super::parse_fixed_header; + + use super::{Disconnect, DisconnectProperties, DisconnectReasonCode}; + + #[test] + fn disconnect1_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + let expected = Disconnect::new(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect1_encoding_works() { + let mut buffer = BytesMut::new(); + let disconnect = Disconnect::new(); + let expected = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } + + fn sample2() -> Disconnect { + let properties = DisconnectProperties { + // TODO: change to 2137 xD + session_expiry_interval: Some(1234), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + server_reference: Some("test".to_owned()), + }; + + Disconnect { + reason_code: DisconnectReasonCode::UnspecifiedError, + properties: Some(properties), + } + } + + fn sample_bytes2() -> Vec { + vec![ + 0xE0, // Packet type + 0x22, // Remaining length + 0x80, // Disconnect Reason Code + 0x20, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x1F, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // User properties + 0x1C, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + ] + } + + #[test] + fn disconnect2_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = sample_bytes2(); + let expected = sample2(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect2_encoding_works() { + let mut buffer = BytesMut::new(); + + let disconnect = sample2(); + let expected = sample_bytes2(); + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } +} diff --git a/rumqttc/src/v5/packet/mod.rs b/rumqttc/src/v5/packet/mod.rs new file mode 100644 index 000000000..8f954c1cf --- /dev/null +++ b/rumqttc/src/v5/packet/mod.rs @@ -0,0 +1,489 @@ +use std::{ + fmt::{self, Display, Formatter}, + slice::Iter, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +mod connack; +mod connect; +mod disconnect; +mod ping; +mod puback; +mod pubcomp; +mod publish; +mod pubrec; +mod pubrel; +mod suback; +mod subscribe; +mod unsuback; +mod unsubscribe; + +pub use connack::*; +pub use connect::*; +pub use disconnect::*; +pub use ping::*; +pub use puback::*; +pub use pubcomp::*; +pub use publish::*; +pub use pubrec::*; +pub use pubrel::*; +pub use suback::*; +pub use subscribe::*; +pub use unsuback::*; +pub use unsubscribe::*; + +/// Encapsulates all MQTT packet types +#[derive(Debug, Clone, PartialEq)] +pub enum Packet { + Connect(Connect), + ConnAck(ConnAck), + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubRel(PubRel), + PubComp(PubComp), + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + PingReq, + PingResp, + Disconnect(Disconnect), +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Error { + NotConnect(PacketType), + UnexpectedConnect, + InvalidConnectReturnCode(u8), + InvalidReason(u8), + InvalidProtocol, + InvalidProtocolLevel(u8), + IncorrectPacketFormat, + InvalidPacketType(u8), + InvalidPropertyType(u8), + InvalidRetainForwardRule(u8), + InvalidQoS(u8), + InvalidSubscribeReasonCode(u8), + PacketIdZero, + SubscriptionIdZero, + PayloadSizeIncorrect, + PayloadTooLong, + PayloadSizeLimitExceeded(usize), + PayloadRequired, + TopicNotUtf8, + BoundaryCrossed(usize), + MalformedPacket, + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + InsufficientBytes(usize), +} + +/// Protocol type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Protocol { + V4, + V5, +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::read(fixed_header, packet)?), + }; + + Ok(packet) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Reads a string from bytes stream +fn read_mqtt_string(stream: &mut Bytes) -> Result { + let s = read_mqtt_bytes(stream)?; + match String::from_utf8(s.to_vec()) { + Ok(v) => Ok(v), + Err(_e) => Err(Error::TopicNotUtf8), + } +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes +fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.is_empty() { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Error = {:?}", self) + } +} diff --git a/rumqttc/src/v5/packet/ping.rs b/rumqttc/src/v5/packet/ping.rs new file mode 100644 index 000000000..1072029bd --- /dev/null +++ b/rumqttc/src/v5/packet/ping.rs @@ -0,0 +1,20 @@ +use super::*; +use bytes::{BufMut, BytesMut}; + +pub struct PingReq; + +impl PingReq { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +pub struct PingResp; + +impl PingResp { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} diff --git a/rumqttc/src/v5/packet/puback.rs b/rumqttc/src/v5/packet/puback.rs new file mode 100644 index 000000000..51131949e --- /dev/null +++ b/rumqttc/src/v5/packet/puback.rs @@ -0,0 +1,324 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, +} + +impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties, sending reason code is optional + if self.reason == PubAckReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // Reason code is optional with success if there are no properties + if self.reason == PubAckReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(v5)] +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubAck { + let properties = PubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x40, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn puback_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample()); + } + + #[test] + fn puback_encoding_works() { + let puback = sample(); + let mut buf = BytesMut::new(); + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![0x40, 0x03, 0x00, 0x2a, 0x10] + } + + #[test] + fn puback2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample2()); + } + + #[test] + fn puback2_encoding_works() { + let puback = sample2(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample2_bytes()); + } + + fn sample3() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::Success, + properties: None, + } + } + + fn sample3_bytes() -> Vec { + vec![0x40, 0x02, 0x00, 0x2a] + } + + #[test] + fn puback3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample3()); + } + + #[test] + fn puback3_encoding_works() { + let puback = sample3(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample3_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubcomp.rs b/rumqttc/src/v5/packet/pubcomp.rs new file mode 100644 index 000000000..badb97867 --- /dev/null +++ b/rumqttc/src/v5/packet/pubcomp.rs @@ -0,0 +1,237 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubCompReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubComp { + pub pkid: u16, + pub reason: PubCompReason, + pub properties: Option, +} + +impl PubComp { + pub fn new(pkid: u16) -> PubComp { + PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if fixed_header.remaining_len == 2 { + return Ok(PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubComp { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubComp { + pkid, + reason: reason(ack_reason)?, + properties: PubCompProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x70); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubCompProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubCompProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubCompProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubCompReason::Success, + 146 => PubCompReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubComp { + let properties = PubCompProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubComp { + pkid: 42, + reason: PubCompReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x70, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubcomp_parsing_works_correctly() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubcomp_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubcomp = PubComp::read(fixed_header, pubcomp_bytes).unwrap(); + assert_eq!(pubcomp, sample()); + } + + #[test] + fn pubcomp_encoding_works_correctly() { + let pubcomp = sample(); + let mut buf = BytesMut::new(); + pubcomp.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/publish.rs b/rumqttc/src/v5/packet/publish.rs new file mode 100644 index 000000000..9f0e228ea --- /dev/null +++ b/rumqttc/src/v5/packet/publish.rs @@ -0,0 +1,394 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Publish packet +#[derive(Clone, PartialEq)] +pub struct Publish { + pub dup: bool, + pub qos: QoS, + pub retain: bool, + pub topic: String, + pub pkid: u16, + pub properties: Option, + pub payload: Bytes, +} + +impl Publish { + pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload: Bytes::from(payload.into()), + } + } + + pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.topic.len(); + if self.qos != QoS::AtMostOnce && self.pkid != 0 { + len += 2; + } + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += self.payload.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let qos = qos((fixed_header.byte1 & 0b0110) >> 1)?; + let dup = (fixed_header.byte1 & 0b1000) != 0; + let retain = (fixed_header.byte1 & 0b0001) != 0; + + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let topic = read_mqtt_string(&mut bytes)?; + + // Packet identifier exists where QoS > 0 + let pkid = match qos { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?, + }; + + if qos != QoS::AtMostOnce && pkid == 0 { + return Err(Error::PacketIdZero); + } + + let publish = Publish { + dup, + retain, + qos, + pkid, + topic, + properties: PublishProperties::extract(&mut bytes)?, + payload: bytes, + }; + + Ok(publish) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + + let dup = self.dup as u8; + let qos = self.qos as u8; + let retain = self.retain as u8; + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, self.topic.as_str()); + + if self.qos != QoS::AtMostOnce { + let pkid = self.pkid; + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + buffer.extend_from_slice(&self.payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PublishProperties { + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub topic_alias: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, + pub subscription_identifiers: Vec, + pub content_type: Option, +} + +impl PublishProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.topic_alias.is_some() { + len += 1 + 2; + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + for id in self.subscription_identifiers.iter() { + len += 1 + len_len(*id); + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut topic_alias = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + let mut subscription_identifiers = Vec::new(); + let mut content_type = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAlias => { + topic_alias = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::SubscriptionIdentifier => { + let (id_len, id) = length(bytes.iter())?; + cursor += 1 + id_len; + bytes.advance(id_len); + subscription_identifiers.push(id); + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PublishProperties { + payload_format_indicator, + message_expiry_interval, + topic_alias, + response_topic, + correlation_data, + user_properties, + subscription_identifiers, + content_type, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(topic_alias) = self.topic_alias { + buffer.put_u8(PropertyType::TopicAlias as u8); + buffer.put_u16(topic_alias); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + for id in self.subscription_identifiers.iter() { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + Ok(()) + } +} + +impl fmt::Debug for Publish { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Topic = {}, Qos = {:?}, Retain = {}, Pkid = {:?}, Payload Size = {}", + self.topic, + self.qos, + self.retain, + self.pkid, + self.payload.len() + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample_v5() -> Publish { + let publish_properties = PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: Some(4321), + topic_alias: Some(100), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + subscription_identifiers: vec![120, 121], + content_type: Some("test".to_owned()), + }; + + Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: false, + topic: "test".to_string(), + pkid: 42, + properties: Some(publish_properties), + payload: Bytes::from(vec![b't', b'e', b's', b't']), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x34, // payload type + 0x3e, // remaining len + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // topic name + 0x00, 0x2a, // pkid + 0x31, // properties len + 0x01, 0x01, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message_expiry_interval + 0x23, 0x00, 0x64, // topic alias + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x0b, 0x78, // subscription_identifier + 0x0b, 0x79, // subscription_identifier + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content_type + 0x74, 0x65, 0x73, 0x74, // payload + ] + } + + #[test] + fn publish_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let publish_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let publish = Publish::read(fixed_header, publish_bytes).unwrap(); + assert_eq!(publish, sample_v5()); + } + + #[test] + fn publish_encoding_works() { + let publish = sample_v5(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/pubrec.rs b/rumqttc/src/v5/packet/pubrec.rs new file mode 100644 index 000000000..5e8de572e --- /dev/null +++ b/rumqttc/src/v5/packet/pubrec.rs @@ -0,0 +1,252 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRecReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRec { + pub pkid: u16, + pub reason: PubRecReason, + pub properties: Option, +} + +impl PubRec { + pub fn new(pkid: u16) -> PubRec { + PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRec { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRec { + pkid, + reason: reason(ack_reason)?, + properties: PubRecProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x50); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRecProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRecProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRecProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRecReason::Success, + 16 => PubRecReason::NoMatchingSubscribers, + 128 => PubRecReason::UnspecifiedError, + 131 => PubRecReason::ImplementationSpecificError, + 135 => PubRecReason::NotAuthorized, + 144 => PubRecReason::TopicNameInvalid, + 145 => PubRecReason::PacketIdentifierInUse, + 151 => PubRecReason::QuotaExceeded, + 153 => PubRecReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRec { + let properties = PubRecProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRec { + pkid: 42, + reason: PubRecReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x50, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrec_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrec_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrec = PubRec::read(fixed_header, pubrec_bytes).unwrap(); + assert_eq!(pubrec, sample()); + } + + #[test] + fn pubrec_encoding_works() { + let pubrec = sample(); + let mut buf = BytesMut::new(); + pubrec.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubrel.rs b/rumqttc/src/v5/packet/pubrel.rs new file mode 100644 index 000000000..1a1a62e4d --- /dev/null +++ b/rumqttc/src/v5/packet/pubrel.rs @@ -0,0 +1,236 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRelReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRel { + pub pkid: u16, + pub reason: PubRelReason, + pub properties: Option, +} + +impl PubRel { + pub fn new(pkid: u16) -> PubRel { + PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRel { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRel { + pkid, + reason: reason(ack_reason)?, + properties: PubRelProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x62); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRelProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRelProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRelProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRelReason::Success, + 146 => PubRelReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRel { + let properties = PubRelProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRel { + pkid: 42, + reason: PubRelReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x62, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrel_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrel_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrel = PubRel::read(fixed_header, pubrel_bytes).unwrap(); + assert_eq!(pubrel, sample()); + } + + #[test] + fn pubrel_encoding_works() { + let pubrel = sample(); + let mut buf = BytesMut::new(); + pubrel.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/suback.rs b/rumqttc/src/v5/packet/suback.rs new file mode 100644 index 000000000..0ec0ead05 --- /dev/null +++ b/rumqttc/src/v5/packet/suback.rs @@ -0,0 +1,263 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::convert::{TryFrom, TryInto}; + +/// Acknowledgement to subscribe +#[derive(Debug, Clone, PartialEq)] +pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, +} + +impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0x90); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, +} + +impl TryFrom for SubscribeReasonCode { + type Error = super::Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(super::Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> SubAck { + let properties = SubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + SubAck { + pkid: 42, + return_codes: vec![ + SubscribeReasonCode::QoS0, + SubscribeReasonCode::QoS1, + SubscribeReasonCode::QoS2, + SubscribeReasonCode::Unspecified, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x90, // packet type + 0x1b, // remaining len + 0x00, 0x2a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // user properties + 0x00, 0x01, 0x02, 0x80, // return codes + ] + } + + #[test] + fn suback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let suback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let suback = SubAck::read(fixed_header, suback_bytes).unwrap(); + assert_eq!(suback, sample()); + } + + #[test] + fn suback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/subscribe.rs b/rumqttc/src/v5/packet/subscribe.rs new file mode 100644 index 000000000..9a5c43d1e --- /dev/null +++ b/rumqttc/src/v5/packet/subscribe.rs @@ -0,0 +1,425 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Subscription packet +#[derive(Clone, PartialEq)] +pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + Subscribe { + pkid: 0, + filters: vec![filter], + properties: None, + } + } + + pub fn new_many(topics: T) -> Subscribe + where + T: IntoIterator, + { + Subscribe { + pkid: 0, + filters: topics.into_iter().collect(), + properties: None, + } + } + + pub fn empty_subscribe() -> Subscribe { + Subscribe { + pkid: 0, + filters: Vec::new(), + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_string(&mut bytes)?; + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = nolocal != 0; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = preserve_retain != 0; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Subscription filter +#[derive(Clone, PartialEq)] +pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, +} + +impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn set_nolocal(&mut self, flag: bool) -> &mut Self { + self.nolocal = flag; + self + } + + pub fn set_preserve_retain(&mut self, flag: bool) -> &mut Self { + self.preserve_retain = flag; + self + } + + pub fn set_retain_forward_rule(&mut self, rule: RetainForwardRule) -> &mut Self { + self.retain_forward_rule = rule; + self + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, +} + +impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } +} + +impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Subscribe { + let subscribe_properties = SubscribeProperties { + id: Some(100), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let mut filter = SubscribeFilter::new("hello".to_owned(), QoS::AtLeastOnce); + filter + .set_nolocal(true) + .set_preserve_retain(true) + .set_retain_forward_rule(RetainForwardRule::Never); + + Subscribe { + pkid: 42, + filters: vec![filter], + properties: Some(subscribe_properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x82, // packet type + 0x1a, // remaining length + 0x00, 0x2a, // pkid + 0x0f, // properties len + 0x0b, 0x64, // subscription identifier + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter + 0x2d, // options + ] + } + + #[test] + fn subscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Subscribe { + let filter = SubscribeFilter::new("hello/world".to_owned(), QoS::AtLeastOnce); + Subscribe { + pkid: 42, + filters: vec![filter], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x82, 0x11, 0x00, 0x2a, 0x00, 0x00, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2f, 0x77, + 0x6f, 0x72, 0x6c, 0x64, 0x01, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsuback.rs b/rumqttc/src/v5/packet/unsuback.rs new file mode 100644 index 000000000..ce01bb4cd --- /dev/null +++ b/rumqttc/src/v5/packet/unsuback.rs @@ -0,0 +1,249 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +//// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum UnsubAckReason { + Success = 0x00, + NoSubscriptionExisted = 0x11, + UnspecifiedError = 0x80, + ImplementationSpecificError = 0x83, + NotAuthorized = 0x87, + TopicFilterInvalid = 0x8F, + PacketIdentifierInUse = 0x91, +} + +/// Acknowledgement to unsubscribe +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAck { + pub pkid: u16, + pub reasons: Vec, + pub properties: Option, +} + +impl UnsubAck { + pub fn new(pkid: u16) -> UnsubAck { + UnsubAck { + pkid, + reasons: Vec::new(), + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.reasons.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = UnsubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut reasons = Vec::new(); + while bytes.has_remaining() { + let r = read_u8(&mut bytes)?; + reasons.push(reason(r)?); + } + + let unsuback = UnsubAck { + pkid, + reasons, + properties, + }; + + Ok(unsuback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xB0); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.reasons.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl UnsubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0x00 => UnsubAckReason::Success, + 0x11 => UnsubAckReason::NoSubscriptionExisted, + 0x80 => UnsubAckReason::UnspecifiedError, + 0x83 => UnsubAckReason::ImplementationSpecificError, + 0x87 => UnsubAckReason::NotAuthorized, + 0x8F => UnsubAckReason::TopicFilterInvalid, + 0x91 => UnsubAckReason::PacketIdentifierInUse, + num => return Err(Error::InvalidSubscribeReasonCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> UnsubAck { + let properties = UnsubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + UnsubAck { + pkid: 10, + reasons: vec![ + UnsubAckReason::NotAuthorized, + UnsubAckReason::TopicFilterInvalid, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xb0, // packet type + 0x19, // remaining len + 0x00, 0x0a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x87, 0x8f, // reasons + ] + } + + #[test] + fn unsuback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsuback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsuback = UnsubAck::read(fixed_header, unsuback_bytes).unwrap(); + assert_eq!(unsuback, sample()); + } + + #[test] + fn unsuback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsubscribe.rs b/rumqttc/src/v5/packet/unsubscribe.rs new file mode 100644 index 000000000..8700c5132 --- /dev/null +++ b/rumqttc/src/v5/packet/unsubscribe.rs @@ -0,0 +1,238 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Unsubscribe packet +#[derive(Debug, Clone, PartialEq)] +pub struct Unsubscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Unsubscribe { + pub fn new>(topic: S) -> Unsubscribe { + Unsubscribe { + pkid: 0, + filters: vec![topic.into()], + properties: None, + } + } + + pub fn len(&self) -> usize { + // Packet id + length of filters (unlike subscribe, this just a string. + // Hence 2 is prefixed for len per filter) + let mut len = 2 + self.filters.iter().fold(0, |s, t| 2 + s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + dbg!(pkid); + let properties = UnsubscribeProperties::extract(&mut bytes)?; + + let mut filters = Vec::with_capacity(1); + while bytes.has_remaining() { + let filter = read_mqtt_string(&mut bytes)?; + filters.push(filter); + } + + let unsubscribe = Unsubscribe { + pkid, + filters, + properties, + }; + Ok(unsubscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xA2); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + write_mqtt_string(buffer, filter); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubscribeProperties { + pub user_properties: Vec<(String, String)>, +} + +impl UnsubscribeProperties { + fn len(&self) -> usize { + let mut len = 0; + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubscribeProperties { user_properties })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Unsubscribe { + let properties = UnsubscribeProperties { + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned(), "world".to_owned()], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xa2, // packet type + 0x1e, // remaining len + 0x00, 0x0a, // pkid + 0x0d, // properties len + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter 1 + 0x00, 0x05, 0x77, 0x6f, 0x72, 0x6c, 0x64, // filter 2 + ] + } + + #[test] + fn unsubscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsubscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsubscribe = Unsubscribe::read(fixed_header, unsubscribe_bytes).unwrap(); + assert_eq!(unsubscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + println!("{:X?}", buf); + println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Unsubscribe { + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned()], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0xa2, 0x0a, 0x00, 0x0a, 0x00, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Unsubscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} From ff8097b779ded03684140d6a84e326a8bb6db58a Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:26:02 +0530 Subject: [PATCH 04/38] rumqttc: v5: use mqttbytes/v5 Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 4 ++-- rumqttc/src/v5/eventloop.rs | 8 ++++---- rumqttc/src/v5/framed.rs | 4 ++-- rumqttc/src/v5/mod.rs | 3 +-- rumqttc/src/v5/state.rs | 12 ++++++------ rumqttc/src/v5/tls.rs | 2 +- 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 8b95c1bc2..53c0eaad5 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -1,10 +1,10 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::v4::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::v5::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use mqttbytes::*; use std::mem; use tokio::runtime; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index d4ce78c7a..8a07317a1 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,11 +1,11 @@ -use crate::v4::{framed::Network, Transport}; -use crate::v4::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::v4::{MqttOptions, Outgoing}; +use crate::v5::{framed::Network, Transport}; +use crate::v5::{tls, Incoming, MqttState, Packet, Request, StateError}; +use crate::v5::{MqttOptions, Outgoing}; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 96bda17da..03855b062 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,9 +1,9 @@ use bytes::BytesMut; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use mqttbytes::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::v4::{Incoming, MqttState, StateError}; +use crate::v5::{Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 01800c685..012f7e0a6 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -12,8 +12,7 @@ mod tls; pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; -pub use mqttbytes::v4::*; -pub use mqttbytes::*; +pub use packet::*; pub use state::{MqttState, StateError}; pub use tokio_rustls::rustls::ClientConfig; pub use tls::Error; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index c51116ccf..e8f72592f 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,7 +1,7 @@ -use crate::v4::{Event, Incoming, Outgoing, Request}; +use super::{Event, Incoming, Outgoing, Request}; use bytes::BytesMut; -use mqttbytes::v4::*; +use mqttbytes::v5::*; use mqttbytes::*; use std::collections::VecDeque; use std::{io, mem, time::Instant}; @@ -422,7 +422,7 @@ impl MqttState { debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", - unsub.topics, unsub.pkid + unsub.filters, unsub.pkid ); unsub.write(&mut self.write)?; @@ -434,7 +434,7 @@ impl MqttState { fn outgoing_disconnect(&mut self) -> Result<(), StateError> { debug!("Disconnect"); - Disconnect.write(&mut self.write)?; + Disconnect::new().write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); Ok(()) @@ -487,8 +487,8 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::v4::{Event, Incoming, MqttOptions, Outgoing, Request}; - use mqttbytes::v4::*; + use crate::v5::{Event, Incoming, MqttOptions, Outgoing, Request}; + use mqttbytes::v5::*; use mqttbytes::*; fn build_outgoing_publish(qos: QoS) -> Publish { diff --git a/rumqttc/src/v5/tls.rs b/rumqttc/src/v5/tls.rs index ea1809e23..3d07819b2 100644 --- a/rumqttc/src/v5/tls.rs +++ b/rumqttc/src/v5/tls.rs @@ -7,7 +7,7 @@ use tokio_rustls::rustls::{ use tokio_rustls::webpki; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::v4::{Key, MqttOptions, TlsConfiguration}; +use crate::v5::{Key, MqttOptions, TlsConfiguration}; use std::convert::TryFrom; use std::io; From 218b8a816721fa921e535b0135a45215026b3afe Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:29:58 +0530 Subject: [PATCH 05/38] rumqttc: v5: use v5::packet instead of mqttbytes::v5 Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 31 ++++++------------------------- rumqttc/src/v5/eventloop.rs | 8 ++++---- rumqttc/src/v5/framed.rs | 6 ++---- rumqttc/src/v5/state.rs | 10 ++++------ 4 files changed, 16 insertions(+), 39 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 53c0eaad5..846ae9901 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -1,11 +1,9 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::v5::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::v5::{packet::*, ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; -use mqttbytes::v5::*; -use mqttbytes::*; use std::mem; use tokio::runtime; use tokio::runtime::Runtime; @@ -94,11 +92,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { @@ -108,11 +102,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { self.request_tx.try_send(ack)?; @@ -217,7 +207,7 @@ fn get_ack_req(publish: &Publish) -> Option { let ack = match publish.qos { QoS::AtMostOnce => return None, QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)) + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), }; Some(ack) } @@ -277,26 +267,17 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { pollster::block_on(self.client.ack(publish))?; Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish)?; Ok(()) } - /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { pollster::block_on(self.client.subscribe(topic, qos))?; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 8a07317a1..fe8b1cb6b 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,11 +1,11 @@ -use crate::v5::{framed::Network, Transport}; -use crate::v5::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::v5::{MqttOptions, Outgoing}; +use crate::v5::{ + framed::Network, packet::*, tls, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, + StateError, Transport, +}; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; -use mqttbytes::v5::*; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 03855b062..684694d5a 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,9 +1,7 @@ use bytes::BytesMut; -use mqttbytes::v5::*; -use mqttbytes::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::v5::{Incoming, MqttState, StateError}; +use crate::v5::{packet::*, Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes @@ -61,7 +59,7 @@ impl Network { loop { let required = match read(&mut self.read, self.max_incoming_size) { Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, + Err(Error::InsufficientBytes(required)) => required, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), }; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index e8f72592f..491eabc62 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,8 +1,6 @@ -use super::{Event, Incoming, Outgoing, Request}; +use super::{Event, Incoming, Outgoing, Request, packet::*}; use bytes::BytesMut; -use mqttbytes::v5::*; -use mqttbytes::*; use std::collections::VecDeque; use std::{io, mem, time::Instant}; @@ -30,11 +28,11 @@ pub enum StateError { #[error("Timeout while waiting to resolve collision")] CollisionTimeout, #[error("Mqtt serialization/deserialization error")] - Deserialization(mqttbytes::Error), + Deserialization(Error), } -impl From for StateError { - fn from(e: mqttbytes::Error) -> StateError { +impl From for StateError { + fn from(e: Error) -> StateError { StateError::Deserialization(e) } } From 076fb2666fe616dd7814feb32ffac13468742dcc Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 11 Feb 2022 18:44:09 +0530 Subject: [PATCH 06/38] rumqttc: fix examples, tests and docs Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 1 - rumqttc/examples/syncpubsub.rs | 4 +++- rumqttc/examples/tls2.rs | 2 +- rumqttc/src/lib.rs | 4 ++-- rumqttc/src/v4/client.rs | 27 +++++---------------------- rumqttc/src/v4/mod.rs | 14 +++++++++----- rumqttc/src/v4/state.rs | 2 +- rumqttc/src/v5/mod.rs | 14 +++++++++----- rumqttc/src/v5/packet/disconnect.rs | 2 +- rumqttc/src/v5/state.rs | 8 +++----- 10 files changed, 34 insertions(+), 44 deletions(-) diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 6790991e3..127e235db 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -14,7 +14,6 @@ fn create_conn() -> (AsyncClient, EventLoop) { AsyncClient::new(mqttoptions, 10) } - #[tokio::main(worker_threads = 1)] async fn main() -> Result<(), Box> { pretty_env_logger::init(); diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index b13abc65f..a14bcda4e 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -7,7 +7,9 @@ fn main() { let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false); - mqttoptions.set_keep_alive(Duration::from_secs(5)).set_last_will(will); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_last_will(will); let (client, mut connection) = Client::new(mqttoptions, 10); thread::spawn(move || publish(client)); diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index 1bf5cf4fb..cd81364a4 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -14,7 +14,7 @@ async fn main() -> Result<(), Box> { // Dummies to prevent compilation error in CI let ca = vec![1, 2, 3]; let client_cert = vec![1, 2, 3]; - let client_key= vec![1, 2, 3]; + let client_key = vec![1, 2, 3]; // let ca = include_bytes!("/home/tekjar/tlsfiles/ca.cert.pem"); // let client_cert = include_bytes!("/home/tekjar/tlsfiles/device-1.cert.pem"); // let client_key = include_bytes!("/home/tekjar/tlsfiles/device-1.key.pem"); diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 0a30cbefd..353397a1e 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -9,7 +9,7 @@ //! ---------------------------- //! //! ```no_run -//! use rumqttc::{MqttOptions, Client, QoS}; +//! use rumqttc::v4::{MqttOptions, Client, QoS}; //! use std::time::Duration; //! use std::thread; //! @@ -33,7 +33,7 @@ //! ------------------------------ //! //! ```no_run -//! use rumqttc::{MqttOptions, AsyncClient, QoS}; +//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; //! use tokio::{task, time}; //! use std::time::Duration; //! use std::error::Error; diff --git a/rumqttc/src/v4/client.rs b/rumqttc/src/v4/client.rs index 8b95c1bc2..623281075 100644 --- a/rumqttc/src/v4/client.rs +++ b/rumqttc/src/v4/client.rs @@ -94,11 +94,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { @@ -108,11 +104,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { self.request_tx.try_send(ack)?; @@ -217,7 +209,7 @@ fn get_ack_req(publish: &Publish) -> Option { let ack = match publish.qos { QoS::AtMostOnce => return None, QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)) + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), }; Some(ack) } @@ -277,26 +269,17 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { pollster::block_on(self.client.ack(publish))?; Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack( - &self, - publish: &Publish - ) -> Result<(), ClientError> - { + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish)?; Ok(()) } - /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { pollster::block_on(self.client.subscribe(topic, qos))?; diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs index 4797b0f2a..882dfdd06 100644 --- a/rumqttc/src/v4/mod.rs +++ b/rumqttc/src/v4/mod.rs @@ -14,8 +14,8 @@ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; pub use state::{MqttState, StateError}; -pub use tokio_rustls::rustls::ClientConfig; pub use tls::Error; +pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -336,7 +336,11 @@ impl MqttOptions { } /// Username and password - pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { self.credentials = Some((username.into(), password.into())); self } @@ -570,7 +574,7 @@ impl std::convert::TryFrom for MqttOptions { inflight, last_will: None, conn_timeout, - manual_acks: false + manual_acks: false, }) } } @@ -613,9 +617,9 @@ mod test { fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + _mqtt_opts.set_transport(crate::v4::Transport::wss(Vec::from("Test CA"), None, None)); - if let crate::Transport::Wss(TlsConfiguration::Simple { + if let crate::v4::Transport::Wss(TlsConfiguration::Simple { ca, client_auth, alpn, diff --git a/rumqttc/src/v4/state.rs b/rumqttc/src/v4/state.rs index c51116ccf..0653d4e98 100644 --- a/rumqttc/src/v4/state.rs +++ b/rumqttc/src/v4/state.rs @@ -100,7 +100,7 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), - manual_acks + manual_acks, } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 012f7e0a6..d4b276151 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -14,8 +14,8 @@ pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; pub use packet::*; pub use state::{MqttState, StateError}; -pub use tokio_rustls::rustls::ClientConfig; pub use tls::Error; +pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -336,7 +336,11 @@ impl MqttOptions { } /// Username and password - pub fn set_credentials, P: Into>(&mut self, username: U, password: P) -> &mut Self { + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { self.credentials = Some((username.into(), password.into())); self } @@ -570,7 +574,7 @@ impl std::convert::TryFrom for MqttOptions { inflight, last_will: None, conn_timeout, - manual_acks: false + manual_acks: false, }) } } @@ -613,9 +617,9 @@ mod test { fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + _mqtt_opts.set_transport(crate::v5::Transport::wss(Vec::from("Test CA"), None, None)); - if let crate::Transport::Wss(TlsConfiguration::Simple { + if let crate::v5::Transport::Wss(TlsConfiguration::Simple { ca, client_auth, alpn, diff --git a/rumqttc/src/v5/packet/disconnect.rs b/rumqttc/src/v5/packet/disconnect.rs index 508f3c427..3662087f1 100644 --- a/rumqttc/src/v5/packet/disconnect.rs +++ b/rumqttc/src/v5/packet/disconnect.rs @@ -1,6 +1,6 @@ use std::convert::{TryFrom, TryInto}; -use bytes::{BufMut, BytesMut, Bytes}; +use bytes::{BufMut, Bytes, BytesMut}; use super::*; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 491eabc62..2bc2b470a 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,4 +1,4 @@ -use super::{Event, Incoming, Outgoing, Request, packet::*}; +use super::{packet::*, Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use std::collections::VecDeque; @@ -98,7 +98,7 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), - manual_acks + manual_acks, } } @@ -485,9 +485,7 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::v5::{Event, Incoming, MqttOptions, Outgoing, Request}; - use mqttbytes::v5::*; - use mqttbytes::*; + use crate::v5::{packet::*, Event, Incoming, MqttOptions, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); From 86ace5385eac31ba52126803f7c2b083438297d5 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sat, 12 Feb 2022 14:58:44 +0530 Subject: [PATCH 07/38] rumqttc: use request_buf in eventloop and client Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 128 +++++++++++++++++------------ rumqttc/src/v5/eventloop.rs | 159 +++++++++++++++++++++--------------- 2 files changed, 169 insertions(+), 118 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 846ae9901..20ad59f95 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -4,19 +4,22 @@ use crate::v5::{packet::*, ConnectionError, Event, EventLoop, MqttOptions, Reque use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; -use std::mem; -use tokio::runtime; -use tokio::runtime::Runtime; +use std::{ + collections::VecDeque, + mem, + sync::{Arc, Mutex}, +}; +use tokio::runtime::{self, Runtime}; /// Client Error #[derive(Debug, thiserror::Error)] pub enum ClientError { #[error("Failed to send cancel request to eventloop")] - Cancel(#[from] SendError<()>), + Cancel(SendError<()>), #[error("Failed to send mqtt requests to eventloop")] - Request(#[from] SendError), + Request(#[from] SendError<()>), #[error("Failed to send mqtt requests to eventloop")] - TryRequest(#[from] TrySendError), + TryRequest(#[from] TrySendError<()>), #[error("Serialization error")] Mqtt4(mqttbytes::Error), } @@ -25,7 +28,8 @@ pub enum ClientError { /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_buf: Arc>>, + request_tx: Sender<()>, cancel_tx: Sender<()>, } @@ -33,10 +37,12 @@ impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let mut eventloop = EventLoop::new(options, cap); + let request_buf = eventloop.buf().clone(); let request_tx = eventloop.handle(); let cancel_tx = eventloop.cancel_handle(); let client = AsyncClient { + request_buf, request_tx, cancel_tx, }; @@ -46,8 +52,13 @@ impl AsyncClient { /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for /// creating a test instance. - pub fn from_senders(request_tx: Sender, cancel_tx: Sender<()>) -> AsyncClient { + pub fn from_senders( + request_buf: Arc>>, + request_tx: Sender<()>, + cancel_tx: Sender<()>, + ) -> AsyncClient { AsyncClient { + request_buf, request_tx, cancel_tx, } @@ -60,16 +71,18 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; + let pkid = publish.pkid; let publish = Request::Publish(publish); - self.request_tx.send(publish).await?; - Ok(()) + self.request_buf.lock().unwrap().push_back(publish); + self.request_tx.send(()).await?; + Ok(pkid) } /// Sends a MQTT Publish to the eventloop @@ -79,16 +92,18 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; + let pkid = publish.pkid; let publish = Request::Publish(publish); - self.request_tx.try_send(publish)?; - Ok(()) + self.request_buf.lock().unwrap().push_back(publish); + self.request_tx.try_send(())?; + Ok(pkid) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. @@ -96,7 +111,8 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.send(ack).await?; + self.request_buf.lock().unwrap().push_back(ack); + self.request_tx.send(()).await?; } Ok(()) } @@ -105,7 +121,8 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_buf.lock().unwrap().push_back(ack); + self.request_tx.try_send(())?; } Ok(()) } @@ -124,7 +141,8 @@ impl AsyncClient { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; let publish = Request::Publish(publish); - self.request_tx.send(publish).await?; + self.request_buf.lock().unwrap().push_back(publish); + self.request_tx.send(()).await?; Ok(()) } @@ -132,7 +150,8 @@ impl AsyncClient { pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); let request = Request::Subscribe(subscribe); - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } @@ -140,7 +159,8 @@ impl AsyncClient { pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); let request = Request::Subscribe(subscribe); - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } @@ -151,7 +171,8 @@ impl AsyncClient { { let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } @@ -162,7 +183,8 @@ impl AsyncClient { { let subscribe = Subscribe::new_many(topics); let request = Request::Subscribe(subscribe); - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } @@ -170,7 +192,8 @@ impl AsyncClient { pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } @@ -178,28 +201,42 @@ impl AsyncClient { pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.send(request).await?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await?; Ok(()) } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.try_send(request)?; + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(())?; Ok(()) } /// Stops the eventloop right away pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx.send(()).await?; - Ok(()) + self.cancel_tx.send(()).await.map_err(ClientError::Cancel) + } + + #[inline] + async fn send_and_notify(&self, request: Request) -> Result<(), async_channel::SendError<()>> { + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.send(()).await + } + + #[inline] + fn try_send_and_notify(&self, request: Request) -> Result<(), async_channel::TrySendError<()>> { + self.request_buf.lock().unwrap().push_back(request); + self.request_tx.try_send(()) } } @@ -242,13 +279,12 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - pollster::block_on(self.client.publish(topic, qos, retain, payload))?; - Ok(()) + pollster::block_on(self.client.publish(topic, qos, retain, payload)) } pub fn try_publish( @@ -257,31 +293,27 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - pollster::block_on(self.client.ack(publish))?; - Ok(()) + pollster::block_on(self.client.ack(publish)) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { - pollster::block_on(self.client.subscribe(topic, qos))?; - Ok(()) + pollster::block_on(self.client.subscribe(topic, qos)) } /// Sends a MQTT Subscribe to the eventloop @@ -290,8 +322,7 @@ impl Client { topic: S, qos: QoS, ) -> Result<(), ClientError> { - self.client.try_subscribe(topic, qos)?; - Ok(()) + self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -311,32 +342,27 @@ impl Client { /// Sends a MQTT Unsubscribe to the eventloop pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - pollster::block_on(self.client.unsubscribe(topic))?; - Ok(()) + pollster::block_on(self.client.unsubscribe(topic)) } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - self.client.try_unsubscribe(topic)?; - Ok(()) + self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.disconnect())?; - Ok(()) + pollster::block_on(self.client.disconnect()) } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&mut self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + self.client.try_disconnect() } /// Stops the eventloop right away pub fn cancel(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.cancel())?; - Ok(()) + pollster::block_on(self.client.cancel()) } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index fe8b1cb6b..1c536db26 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -14,12 +14,16 @@ use tokio::time::{self, error::Elapsed, Instant, Sleep}; #[cfg(feature = "websocket")] use ws_stream_tungstenite::WsStream; -use std::io; #[cfg(unix)] use std::path::Path; -use std::pin::Pin; -use std::time::Duration; -use std::vec::IntoIter; +use std::{ + collections::VecDeque, + io, + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, + vec::IntoIter, +}; /// Critical errors during eventloop polling #[derive(Debug, thiserror::Error)] @@ -48,10 +52,12 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, + request_buf: Arc>>, + request_buf_cache: VecDeque, /// Request stream - pub requests_rx: Receiver, + pub requests_rx: Receiver<()>, /// Requests handle to send requests - pub requests_tx: Sender, + pub requests_tx: Sender<()>, /// Pending packets from last session pub pending: IntoIter, /// Network connection to the broker @@ -78,7 +84,8 @@ impl EventLoop { /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { let (cancel_tx, cancel_rx) = bounded(5); - let (requests_tx, requests_rx) = bounded(cap); + let (requests_tx, requests_rx) = bounded(1); + let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); let pending = pending.into_iter(); let max_inflight = options.inflight; @@ -87,6 +94,8 @@ impl EventLoop { EventLoop { options, state: MqttState::new(max_inflight, manual_acks), + request_buf, + request_buf_cache: VecDeque::with_capacity(cap), requests_tx, requests_rx, pending, @@ -98,10 +107,14 @@ impl EventLoop { } /// Returns a handle to communicate with this eventloop - pub fn handle(&self) -> Sender { + pub fn handle(&self) -> Sender<()> { self.requests_tx.clone() } + pub fn buf(&self) -> &Arc>> { + &self.request_buf + } + /// Handle for cancelling the eventloop. /// /// Can be useful in cases when connection should be halted immediately @@ -156,68 +169,80 @@ impl EventLoop { return Ok(event); } - // this loop is necessary since self.incoming.pop_front() might return None. In that case, - // instead of returning a None event, we try again. - select! { - // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // Pull next request from user requests channel. - // If conditions in the below branch are for flow control. We read next user - // user request only when inflight messages are < configured inflight and there - // are no collisions while handling previous outgoing requests. - // - // Flow control is based on ack count. If inflight packet count in the buffer is - // less than max_inflight setting, next outgoing request will progress. For this - // to work correctly, broker should ack in sequence (a lot of brokers won't) - // - // E.g If max inflight = 5, user requests will be blocked when inflight queue - // looks like this -> [1, 2, 3, 4, 5]. - // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. - // This pulls next user request. But because max packet id = max_inflight, next - // user request's packet id will roll to 1. This replaces existing packet id 1. - // Resulting in a collision - // - // Eventloop can stop receiving outgoing user requests when previous outgoing - // request collided. I.e collision state. Collision state will be cleared only - // when correct ack is received - // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. - // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. - // After collision with pkid 1 -> [1b ,2, x, 4, 5]. - // 1a is saved to state and event loop is set to collision mode stopping new - // outgoing requests (along with 1b). - o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { - Ok(request) => { + // this loop is necessary as self.request_buf might be empty, in which case it is possible + // for self.state.events to be empty, and so popping off from it might return None. If None + // is returned, we select again. + loop { + select! { + // Pull a bunch of packets from network, reply in bunch and yield the first item + o = network.readb(&mut self.state) => { + o?; + // flush all the acks and return first incoming packet + network.flush(&mut self.state.write).await?; + return Ok(self.state.events.pop_front().unwrap()); + }, + // Pull next request from user requests channel. + // If conditions in the below branch are for flow control. We read next user + // request only when inflight messages are < configured inflight and there are no + // collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this to + // work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue looks + // like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { + Ok(_request_notif) => { + // swapping to avoid blocking the mutex + std::mem::swap(&mut self.request_buf_cache,&mut *self.request_buf.lock().unwrap()); + if self.request_buf_cache.is_empty() { + continue; + } + for request in self.request_buf_cache.drain(..) { + self.state.handle_outgoing_packet(request)?; + } + network.flush(&mut self.state.write).await?; + // remaining events in the self.state.events will be taken out in next call + // to poll() even before the select! is used. + return Ok(self.state.events.pop_front().unwrap()) + } + Err(_) => return Err(ConnectionError::RequestsDone), + }, + // Handle the next pending packet from previous session. Disable + // this branch when done with all the pending packets + Some(request) = next_pending(throttle, &mut self.pending), if pending => { self.state.handle_outgoing_packet(request)?; network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) + return Ok(self.state.events.pop_front().unwrap()) + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + + self.state.handle_outgoing_packet(Request::PingReq)?; + network.flush(&mut self.state.write).await?; + return Ok(self.state.events.pop_front().unwrap()) + } + // cancellation requests to stop the polling + _ = self.cancel_rx.recv() => { + return Err(ConnectionError::Cancel) } - Err(_) => Err(ConnectionError::RequestsDone), - }, - // Handle the next pending packet from previous session. Disable - // this branch when done with all the pending packets - Some(request) = next_pending(throttle, &mut self.pending), if pending => { - self.state.handle_outgoing_packet(request)?; - network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // We generate pings irrespective of network activity. This keeps the ping logic - // simple. We can change this behavior in future if necessary (to prevent extra pings) - _ = self.keepalive_timeout.as_mut().unwrap() => { - let timeout = self.keepalive_timeout.as_mut().unwrap(); - timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - - self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush(&mut self.state.write).await?; - Ok(self.state.events.pop_front().unwrap()) - } - // cancellation requests to stop the polling - _ = self.cancel_rx.recv() => { - Err(ConnectionError::Cancel) } } } From 76ad5411fa1816cff91c947066c7f8c72163b87b Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sat, 12 Feb 2022 15:25:19 +0530 Subject: [PATCH 08/38] rumqttc: client: more intuitive API + code reduction Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 100 ++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 58 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 20ad59f95..0d5acc461 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -16,10 +16,10 @@ use tokio::runtime::{self, Runtime}; pub enum ClientError { #[error("Failed to send cancel request to eventloop")] Cancel(SendError<()>), - #[error("Failed to send mqtt requests to eventloop")] - Request(#[from] SendError<()>), - #[error("Failed to send mqtt requests to eventloop")] - TryRequest(#[from] TrySendError<()>), + #[error("Failed to send mqtt request to eventloop, the evenloop has been closed")] + EventloopClosed, + #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] + RequestsFull, #[error("Serialization error")] Mqtt4(mqttbytes::Error), } @@ -29,6 +29,7 @@ pub enum ClientError { #[derive(Clone, Debug)] pub struct AsyncClient { request_buf: Arc>>, + request_buf_capacity: usize, request_tx: Sender<()>, cancel_tx: Sender<()>, } @@ -43,6 +44,7 @@ impl AsyncClient { let client = AsyncClient { request_buf, + request_buf_capacity: cap, request_tx, cancel_tx, }; @@ -56,9 +58,11 @@ impl AsyncClient { request_buf: Arc>>, request_tx: Sender<()>, cancel_tx: Sender<()>, + cap: usize, ) -> AsyncClient { AsyncClient { request_buf, + request_buf_capacity: cap, request_tx, cancel_tx, } @@ -79,9 +83,7 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - let publish = Request::Publish(publish); - self.request_buf.lock().unwrap().push_back(publish); - self.request_tx.send(()).await?; + self.send_and_notify(Request::Publish(publish)).await?; Ok(pkid) } @@ -100,9 +102,7 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - let publish = Request::Publish(publish); - self.request_buf.lock().unwrap().push_back(publish); - self.request_tx.try_send(())?; + self.try_send_and_notify(Request::Publish(publish))?; Ok(pkid) } @@ -111,8 +111,7 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_buf.lock().unwrap().push_back(ack); - self.request_tx.send(()).await?; + self.send_and_notify(ack).await?; } Ok(()) } @@ -121,8 +120,7 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_buf.lock().unwrap().push_back(ack); - self.request_tx.try_send(())?; + self.try_send_and_notify(ack)?; } Ok(()) } @@ -140,28 +138,19 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_buf.lock().unwrap().push_back(publish); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Publish(publish)).await } /// Sends a MQTT Subscribe to the eventloop pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Subscribe(subscribe)).await } /// Sends a MQTT Subscribe to the eventloop pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -170,10 +159,7 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Subscribe(subscribe)).await } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -182,44 +168,30 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - let request = Request::Subscribe(subscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Unsubscribe(unsubscribe)) + .await } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await?; - Ok(()) + self.send_and_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(())?; - Ok(()) + self.try_send_and_notify(Request::Disconnect) } /// Stops the eventloop right away @@ -227,16 +199,28 @@ impl AsyncClient { self.cancel_tx.send(()).await.map_err(ClientError::Cancel) } - #[inline] - async fn send_and_notify(&self, request: Request) -> Result<(), async_channel::SendError<()>> { - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.send(()).await + async fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(SendError(_)) = self.request_tx.send(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) } - #[inline] - fn try_send_and_notify(&self, request: Request) -> Result<(), async_channel::TrySendError<()>> { - self.request_buf.lock().unwrap().push_back(request); - self.request_tx.try_send(()) + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Closed(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) } } From a491af494f1ef61fe43cd79ef749bbc4eb72dd41 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sat, 12 Feb 2022 15:25:45 +0530 Subject: [PATCH 09/38] rumqttc: examples: remove needless import Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 2 +- rumqttc/examples/asyncpubsub.rs | 2 +- rumqttc/examples/syncpubsub.rs | 2 +- rumqttc/examples/tls.rs | 2 +- rumqttc/examples/tls2.rs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 127e235db..9f2678b51 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/asyncpubsub.rs b/rumqttc/examples/asyncpubsub.rs index fef15451d..b4de624e7 100644 --- a/rumqttc/examples/asyncpubsub.rs +++ b/rumqttc/examples/asyncpubsub.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{self, AsyncClient, MqttOptions, QoS}; +use rumqttc::v4::{AsyncClient, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index a14bcda4e..2001b4d29 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -1,4 +1,4 @@ -use rumqttc::v4::{self, Client, LastWill, MqttOptions, QoS}; +use rumqttc::v4::{Client, LastWill, MqttOptions, QoS}; use std::thread; use std::time::Duration; diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index 90bb4180f..dedf975db 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::v4::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; +use rumqttc::v4::{AsyncClient, Event, Incoming, MqttOptions, Transport}; use rustls::ClientConfig; use std::error::Error; diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index cd81364a4..f6392b99d 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,6 +1,6 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. -use rumqttc::v4::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; +use rumqttc::v4::{AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; use std::error::Error; #[tokio::main] From 9401e29758790db94291983fd80bf6a17924d4f8 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sun, 13 Feb 2022 18:03:07 +0530 Subject: [PATCH 10/38] rumqttc: v5: replace async_channel + pollseter -> flume Signed-off-by: Abhik Jain --- rumqttc/Cargo.toml | 1 + rumqttc/src/v5/client.rs | 74 +++++++++++++++++++++++++------------ rumqttc/src/v5/eventloop.rs | 8 ++-- rumqttc/src/v5/mod.rs | 2 +- 4 files changed, 56 insertions(+), 29 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 81a7b05b7..bbf7dbd7b 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -31,6 +31,7 @@ log = "0.4" thiserror = "1.0.21" http = "^0.2" url = { version = "2.2", default-features = false, optional = true } +flume = "0.10.10" [dev-dependencies] pretty_env_logger = "0.4" diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 0d5acc461..6f2930bd2 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -2,8 +2,8 @@ //! async eventloop. use crate::v5::{packet::*, ConnectionError, Event, EventLoop, MqttOptions, Request}; -use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; use std::{ collections::VecDeque, mem, @@ -83,7 +83,8 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - self.send_and_notify(Request::Publish(publish)).await?; + self.send_async_and_notify(Request::Publish(publish)) + .await?; Ok(pkid) } @@ -108,18 +109,15 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - - if let Some(ack) = ack { - self.send_and_notify(ack).await?; + if let Some(ack) = get_ack_req(publish) { + self.send_async_and_notify(ack).await?; } Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - if let Some(ack) = ack { + if let Some(ack) = get_ack_req(publish) { self.try_send_and_notify(ack)?; } Ok(()) @@ -138,13 +136,14 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - self.send_and_notify(Request::Publish(publish)).await + self.send_async_and_notify(Request::Publish(publish)).await } /// Sends a MQTT Subscribe to the eventloop pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.send_and_notify(Request::Subscribe(subscribe)).await + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await } /// Sends a MQTT Subscribe to the eventloop @@ -159,7 +158,8 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.send_and_notify(Request::Subscribe(subscribe)).await + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -174,7 +174,7 @@ impl AsyncClient { /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - self.send_and_notify(Request::Unsubscribe(unsubscribe)) + self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) .await } @@ -186,7 +186,7 @@ impl AsyncClient { /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { - self.send_and_notify(Request::Disconnect).await + self.send_async_and_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop @@ -196,16 +196,31 @@ impl AsyncClient { /// Stops the eventloop right away pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx.send(()).await.map_err(ClientError::Cancel) + self.cancel_tx + .send_async(()) + .await + .map_err(ClientError::Cancel) } - async fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.request_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); - if let Err(SendError(_)) = self.request_tx.send(()).await { + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(SendError(_)) = self.request_tx.send(()) { return Err(ClientError::EventloopClosed); }; Ok(()) @@ -217,7 +232,7 @@ impl AsyncClient { return Err(ClientError::RequestsFull); } request_buf.push_back(request); - if let Err(TrySendError::Closed(_)) = self.request_tx.try_send(()) { + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { return Err(ClientError::EventloopClosed); } Ok(()) @@ -268,7 +283,11 @@ impl Client { S: Into, V: Into>, { - pollster::block_on(self.client.publish(topic, qos, retain, payload)) + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.client.send_and_notify(Request::Publish(publish))?; + Ok(pkid) } pub fn try_publish( @@ -287,7 +306,10 @@ impl Client { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - pollster::block_on(self.client.ack(publish)) + if let Some(ack) = get_ack_req(publish) { + self.client.send_and_notify(ack)?; + } + Ok(()) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. @@ -297,7 +319,8 @@ impl Client { /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { - pollster::block_on(self.client.subscribe(topic, qos)) + let subscribe = Subscribe::new(topic.into(), qos); + self.client.send_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe to the eventloop @@ -314,7 +337,8 @@ impl Client { where T: IntoIterator, { - pollster::block_on(self.client.subscribe_many(topics)) + let subscribe = Subscribe::new_many(topics); + self.client.send_and_notify(Request::Subscribe(subscribe)) } pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> @@ -326,7 +350,9 @@ impl Client { /// Sends a MQTT Unsubscribe to the eventloop pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - pollster::block_on(self.client.unsubscribe(topic)) + let unsubscribe = Unsubscribe::new(topic.into()); + self.client + .send_and_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT Unsubscribe to the eventloop @@ -336,7 +362,7 @@ impl Client { /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.disconnect()) + self.client.send_and_notify(Request::Disconnect) } /// Sends a MQTT disconnect to the eventloop @@ -346,7 +372,7 @@ impl Client { /// Stops the eventloop right away pub fn cancel(&mut self) -> Result<(), ClientError> { - pollster::block_on(self.client.cancel()) + self.client.cancel_tx.send(()).map_err(ClientError::Cancel) } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 1c536db26..3eb1dd6c6 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,9 +3,9 @@ use crate::v5::{ StateError, Transport, }; -use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use flume::{bounded, Receiver, Sender}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; @@ -205,7 +205,7 @@ impl EventLoop { // After collision with pkid 1 -> [1b ,2, x, 4, 5]. // 1a is saved to state and event loop is set to collision mode stopping new // outgoing requests (along with 1b). - o = self.requests_rx.recv(), if !inflight_full && !pending && !collision => match o { + o = self.requests_rx.recv_async(), if !inflight_full && !pending && !collision => match o { Ok(_request_notif) => { // swapping to avoid blocking the mutex std::mem::swap(&mut self.request_buf_cache,&mut *self.request_buf.lock().unwrap()); @@ -240,7 +240,7 @@ impl EventLoop { return Ok(self.state.events.pop_front().unwrap()) } // cancellation requests to stop the polling - _ = self.cancel_rx.recv() => { + _ = self.cancel_rx.recv_async() => { return Err(ConnectionError::Cancel) } } @@ -256,7 +256,7 @@ async fn connect_or_cancel( // resolved. Returns with an error if connections fail continuously select! { o = connect(options) => o, - _ = cancel_rx.recv() => { + _ = cancel_rx.recv_async() => { Err(ConnectionError::Cancel) } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d4b276151..83e21ef01 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -9,9 +9,9 @@ mod packet; mod state; mod tls; -pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use flume::{SendError, Sender, TrySendError}; pub use packet::*; pub use state::{MqttState, StateError}; pub use tls::Error; From 72fb077b1f90569c2e742fa336cb88cada083d14 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 15 Feb 2022 11:40:58 +0530 Subject: [PATCH 11/38] rumqttc: v5: AsyncClient: limit scope of MutecGuard for async runtime Signed-off-by: Abhik Jain --- rumqttc/src/v5/client.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 6f2930bd2..06eec457e 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -203,11 +203,13 @@ impl AsyncClient { } async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { - return Err(ClientError::RequestsFull); + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); } - request_buf.push_back(request); if let Err(SendError(_)) = self.request_tx.send_async(()).await { return Err(ClientError::EventloopClosed); }; From 2269cd1e85d9a477815547a8f3342fa826d7a4d8 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 15 Feb 2022 11:43:54 +0530 Subject: [PATCH 12/38] rumqttc: v5: add examples Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks_v5.rs | 76 ++++++++++++++++++++++++ rumqttc/examples/asyncpubsub_v5.rs | 43 ++++++++++++++ rumqttc/examples/syncpubsub_v5.rs | 35 +++++++++++ 3 files changed, 154 insertions(+) create mode 100644 rumqttc/examples/async_manual_acks_v5.rs create mode 100644 rumqttc/examples/asyncpubsub_v5.rs create mode 100644 rumqttc/examples/syncpubsub_v5.rs diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs new file mode 100644 index 000000000..eb01d2e3d --- /dev/null +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -0,0 +1,76 @@ +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +fn create_conn() -> (AsyncClient, EventLoop) { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_manual_acks(true) + .set_clean_session(false); + + AsyncClient::new(mqttoptions, 10) +} + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + + // create mqtt connection with clean_session = false and manual_acks = true + let (client, mut eventloop) = create_conn(); + + // subscribe example topic + client + .subscribe("hello/world", QoS::AtLeastOnce) + .await + .unwrap(); + + task::spawn(async move { + // send some messages to example topic and disconnect + requests(client.clone()).await; + client.disconnect().await.unwrap() + }); + + loop { + // get subscribed messages without acking + let event = eventloop.poll().await; + println!("{:?}", event); + if let Err(_err) = event { + // break loop on disconnection + break; + } + } + + // create new broker connection + let (client, mut eventloop) = create_conn(); + + loop { + // previously published messages should be republished after reconnection. + let event = eventloop.poll().await; + println!("{:?}", event); + match event { + Ok(Event::Incoming(Incoming::Publish(publish))) => { + // this time we will ack incoming publishes. + // Its important not to block eventloop as this can cause deadlock. + let c = client.clone(); + tokio::spawn(async move { + c.ack(&publish).await.unwrap(); + }); + } + _ => {} + } + } +} + +async fn requests(client: AsyncClient) { + for i in 1..=10 { + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } +} diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs new file mode 100644 index 000000000..b0b7a6698 --- /dev/null +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -0,0 +1,43 @@ +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + time::sleep(Duration::from_secs(3)).await; + }); + + loop { + let event = eventloop.poll().await; + println!("{:?}", event.unwrap()); + } +} + +async fn requests(client: AsyncClient) { + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap(); + + for i in 1..=10 { + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } + + time::sleep(Duration::from_secs(120)).await; +} diff --git a/rumqttc/examples/syncpubsub_v5.rs b/rumqttc/examples/syncpubsub_v5.rs new file mode 100644 index 000000000..857ab4789 --- /dev/null +++ b/rumqttc/examples/syncpubsub_v5.rs @@ -0,0 +1,35 @@ +use rumqttc::v5::{Client, LastWill, MqttOptions, QoS}; +use std::thread; +use std::time::Duration; + +fn main() { + pretty_env_logger::init(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_last_will(will); + + let (client, mut connection) = Client::new(mqttoptions, 10); + thread::spawn(move || publish(client)); + + for (i, notification) in connection.iter().enumerate() { + println!("{}. Notification = {:?}", i, notification); + } + + println!("Done with the stream!!"); +} + +fn publish(mut client: Client) { + client.subscribe("hello/+/world", QoS::AtMostOnce).unwrap(); + for i in 0..10 { + let payload = vec![1; i as usize]; + let topic = format!("hello/{}/world", i); + let qos = QoS::AtLeastOnce; + + client.publish(topic, qos, true, payload).unwrap(); + } + + thread::sleep(Duration::from_secs(1)); +} From 826e9d0927859e2adcb5bed986a7458fb6b56704 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Tue, 15 Feb 2022 18:34:24 +0530 Subject: [PATCH 13/38] rumqttc: v4: examples: fix Signed-off-by: Abhik Jain --- rumqttc/examples/websocket.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rumqttc/examples/websocket.rs b/rumqttc/examples/websocket.rs index 4336d2f7a..1467723c9 100644 --- a/rumqttc/examples/websocket.rs +++ b/rumqttc/examples/websocket.rs @@ -1,6 +1,8 @@ use tokio::{task, time}; -use rumqttc::{self, AsyncClient, MqttOptions, QoS, Transport}; +use rumqttc::v4::{AsyncClient, MqttOptions, QoS}; +#[cfg(feature = "websocket")] +use rumqttc::v4::Transport; use std::error::Error; use std::time::Duration; From ad91f7e1be0e84c5d0de8239f289b5602538948f Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 17 Feb 2022 16:01:19 +0530 Subject: [PATCH 14/38] benchmarks: add a version agnostic simple router Signed-off-by: Abhik Jain --- Cargo.lock | 269 ++- Cargo.toml | 3 +- benchmarks/simplerouter/Cargo.toml | 13 + .../simplerouter/src/bin/simplerouter.rs | 11 + benchmarks/simplerouter/src/lib.rs | 122 ++ benchmarks/simplerouter/src/network.rs | 102 + benchmarks/simplerouter/src/protocol/mod.rs | 430 ++++ benchmarks/simplerouter/src/protocol/v4.rs | 863 ++++++++ benchmarks/simplerouter/src/protocol/v5.rs | 1952 +++++++++++++++++ 9 files changed, 3672 insertions(+), 93 deletions(-) create mode 100644 benchmarks/simplerouter/Cargo.toml create mode 100644 benchmarks/simplerouter/src/bin/simplerouter.rs create mode 100644 benchmarks/simplerouter/src/lib.rs create mode 100644 benchmarks/simplerouter/src/network.rs create mode 100644 benchmarks/simplerouter/src/protocol/mod.rs create mode 100644 benchmarks/simplerouter/src/protocol/v4.rs create mode 100644 benchmarks/simplerouter/src/protocol/v5.rs diff --git a/Cargo.lock b/Cargo.lock index 8dbaf5cef..5281ecfe4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,18 +103,18 @@ dependencies = [ [[package]] name = "async-tungstenite" -version = "0.13.1" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07b30ef0ea5c20caaa54baea49514a206308989c68be7ecd86c7f956e4da6378" +checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" dependencies = [ "futures-io", "futures-util", "log", "pin-project-lite 0.2.7", - "tokio 1.8.2", + "rustls-native-certs", + "tokio 1.17.0", "tokio-rustls", - "tungstenite 0.13.0", - "webpki-roots", + "tungstenite 0.16.0", ] [[package]] @@ -126,7 +126,7 @@ dependencies = [ "futures 0.3.15", "pharos", "rustc_version 0.3.3", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -185,7 +185,7 @@ version = "0.4.0" dependencies = [ "argh", "async-channel", - "bytes 1.0.1", + "bytes 1.1.0", "futures 0.3.15", "itoa", "jemallocator", @@ -197,7 +197,7 @@ dependencies = [ "rumqttlog 0.9.0", "serde", "serde_json", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -263,9 +263,9 @@ checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" [[package]] name = "cache-padded" @@ -798,7 +798,7 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "825343c4eef0b63f541f8903f395dc5beb362a979b5799a84062527ef1e37726" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -806,7 +806,7 @@ dependencies = [ "http", "indexmap", "slab", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-util", "tracing", ] @@ -831,7 +831,7 @@ checksum = "f0b7591fb62902706ae8e7aaff416b1b0fa2c0fd0878b46dc13baa3712d8a855" dependencies = [ "base64 0.13.0", "bitflags", - "bytes 1.0.1", + "bytes 1.1.0", "headers-core", "http", "mime", @@ -872,7 +872,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "itoa", ] @@ -883,7 +883,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60daa14be0e0786db0f03a9e57cb404c9d756eed2b6c62b9ea98ec5743ec75a9" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite 0.2.7", ] @@ -915,7 +915,7 @@ version = "0.14.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7728a72c4c7d72665fde02204bcbd93b247721025b222ef78606f14513e0fd03" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", @@ -927,7 +927,7 @@ dependencies = [ "itoa", "pin-project-lite 0.2.7", "socket2", - "tokio 1.8.2", + "tokio 1.17.0", "tower-service", "tracing", "want", @@ -978,7 +978,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", ] [[package]] @@ -1111,15 +1111,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.98" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" +checksum = "06e509672465a0504304aa87f9f176f2b2b716ed8fb105ebe5c02dc6dce96a94" [[package]] name = "lock_api" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" dependencies = [ "scopeguard", ] @@ -1211,9 +1211,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.7.13" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" +checksum = "ba272f85fa0b41fc91872be579b3bbe0f56b792aa361a380eb669469f68dafb2" dependencies = [ "libc", "log", @@ -1279,7 +1279,7 @@ dependencies = [ name = "mqttbytes" version = "0.6.0" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "criterion", "pretty_assertions", "rand 0.7.3", @@ -1513,7 +1513,17 @@ checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" dependencies = [ "instant", "lock_api", - "parking_lot_core", + "parking_lot_core 0.8.3", +] + +[[package]] +name = "parking_lot" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" +dependencies = [ + "lock_api", + "parking_lot_core 0.9.1", ] [[package]] @@ -1530,6 +1540,19 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "parking_lot_core" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "pem" version = "0.8.3" @@ -1666,7 +1689,7 @@ dependencies = [ "libc", "log", "nix 0.19.1", - "parking_lot", + "parking_lot 0.11.1", "prost 0.7.0", "prost-build", "prost-derive 0.7.0", @@ -1687,7 +1710,7 @@ dependencies = [ "libc", "log", "nix 0.20.0", - "parking_lot", + "parking_lot 0.11.1", "prost 0.7.0", "prost-build", "prost-derive 0.7.0", @@ -1761,7 +1784,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e6984d2f1a23009bd270b8bb56d0926810a3d483f59c987d77969e9d8e840b2" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "prost-derive 0.7.0", ] @@ -1771,7 +1794,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32d3ebd75ac2679c2af3a92246639f9fcc8a442ee420719cc4fe195b98dd5fa3" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "heck", "itertools 0.9.0", "log", @@ -1815,7 +1838,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b518d7cdd93dab1d1122cf07fa9a60771836c668dde9d9e2a139f957f0d9f1bb" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "prost 0.7.0", ] @@ -2030,7 +2053,7 @@ version = "0.10.0" dependencies = [ "async-channel", "async-tungstenite", - "bytes 1.0.1", + "bytes 1.1.0", "color-backtrace", "crossbeam-channel", "envy", @@ -2043,9 +2066,10 @@ dependencies = [ "pretty_env_logger", "rustls", "rustls-native-certs", + "rustls-pemfile 0.3.0", "serde", "thiserror", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-rustls", "url", "webpki", @@ -2057,7 +2081,7 @@ name = "rumqttd" version = "0.9.0" dependencies = [ "argh", - "bytes 1.0.1", + "bytes 1.1.0", "confy", "futures-util", "jemallocator", @@ -2066,9 +2090,10 @@ dependencies = [ "pprof 0.4.4", "pretty_env_logger", "rumqttlog 0.9.0", + "rustls-pemfile 0.3.0", "serde", "thiserror", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-native-tls", "tokio-rustls", "warp", @@ -2101,7 +2126,7 @@ dependencies = [ "argh", "bencher", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-util", "jackiechan", @@ -2123,9 +2148,9 @@ dependencies = [ name = "rumqttmesh" version = "0.1.0" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "rumqttlog 0.1.4", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -2154,11 +2179,10 @@ dependencies = [ [[package]] name = "rustls" -version = "0.19.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" +checksum = "b323592e3164322f5b193dc4302e4e36cd8d37158a712d664efae1a5c2791700" dependencies = [ - "base64 0.13.0", "log", "ring", "sct", @@ -2167,16 +2191,34 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a07b7c1885bd8ed3831c289b7870b13ef46fe0e856d288c30d9cc17d75a2092" +checksum = "5ca9ebdfa27d3fc180e42879037b5338ab1c040c06affd00d8338598e7800943" dependencies = [ "openssl-probe", - "rustls", + "rustls-pemfile 0.2.1", "schannel", "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" +dependencies = [ + "base64 0.13.0", +] + +[[package]] +name = "rustls-pemfile" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee86d63972a7c661d1536fefe8c3c8407321c3df668891286de28abcd087360" +dependencies = [ + "base64 0.13.0", +] + [[package]] name = "ryu" version = "1.0.5" @@ -2222,9 +2264,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "sct" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" dependencies = [ "ring", "untrusted", @@ -2375,6 +2417,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "simplerouter" +version = "0.1.0" +dependencies = [ + "bytes 1.1.0", + "log", + "pretty_env_logger", + "thiserror", + "tokio 1.17.0", +] + [[package]] name = "slab" version = "0.4.3" @@ -2389,9 +2442,9 @@ checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" [[package]] name = "socket2" -version = "0.4.0" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e3dfc207c526015c632472a77be09cf1b6e46866581aecae5cc38fb4235dea2" +checksum = "66d72b759436ae32898a2af0a14218dbf55efde3feeb170eb623637db85ee1e0" dependencies = [ "libc", "winapi 0.3.9", @@ -2492,18 +2545,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.26" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.26" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" dependencies = [ "proc-macro2", "quote", @@ -2571,21 +2624,21 @@ dependencies = [ [[package]] name = "tokio" -version = "1.8.2" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2602b8af3767c285202012822834005f596c811042315fa7e9f5b12b2a43207" +checksum = "2af73ac49756f3f7c01172e34a23e5d0216f6c32333757c2c61feb2bbff5a5ee" dependencies = [ - "autocfg", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", - "mio 0.7.13", + "mio 0.8.0", "num_cpus", "once_cell", - "parking_lot", + "parking_lot 0.12.0", "pin-project-lite 0.2.7", "signal-hook-registry", - "tokio-macros 1.3.0", + "socket2", + "tokio-macros 1.7.0", "winapi 0.3.9", ] @@ -2602,9 +2655,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.3.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +checksum = "b557f72f448c511a979e2564e55d74e6c4432fc96ff4f6241bc6bded342643b7" dependencies = [ "proc-macro2", "quote", @@ -2618,17 +2671,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" dependencies = [ "native-tls", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] name = "tokio-rustls" -version = "0.22.0" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +checksum = "a27d5f2b839802bd8267fa19b0530f5a08b9c08cd417976be2a65d130fe1c11b" dependencies = [ "rustls", - "tokio 1.8.2", + "tokio 1.17.0", "webpki", ] @@ -2640,7 +2693,7 @@ checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" dependencies = [ "futures-core", "pin-project-lite 0.2.7", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -2652,7 +2705,7 @@ dependencies = [ "futures-util", "log", "pin-project", - "tokio 1.8.2", + "tokio 1.17.0", "tungstenite 0.12.0", ] @@ -2662,12 +2715,12 @@ version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", "futures-sink", "log", "pin-project-lite 0.2.7", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -2720,7 +2773,7 @@ checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24" dependencies = [ "base64 0.13.0", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "http", "httparse", "input_buffer", @@ -2733,16 +2786,15 @@ dependencies = [ [[package]] name = "tungstenite" -version = "0.13.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" +checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" dependencies = [ "base64 0.13.0", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "http", "httparse", - "input_buffer", "log", "rand 0.8.4", "rustls", @@ -2751,7 +2803,6 @@ dependencies = [ "url", "utf-8", "webpki", - "webpki-roots", ] [[package]] @@ -2889,7 +2940,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "332d47745e9a0c38636dbd454729b147d16bd1ed08ae67b3ab281c4506771054" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures 0.3.15", "headers", "http", @@ -2904,7 +2955,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-stream", "tokio-tungstenite", "tokio-util", @@ -2990,23 +3041,14 @@ dependencies = [ [[package]] name = "webpki" -version = "0.21.4" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" dependencies = [ "ring", "untrusted", ] -[[package]] -name = "webpki-roots" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940" -dependencies = [ - "webpki", -] - [[package]] name = "which" version = "4.1.0" @@ -3060,6 +3102,49 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6" +dependencies = [ + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5" + +[[package]] +name = "windows_i686_gnu" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615" + +[[package]] +name = "windows_i686_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316" + [[package]] name = "ws2_32-sys" version = "0.2.1" @@ -3072,9 +3157,9 @@ dependencies = [ [[package]] name = "ws_stream_tungstenite" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34c786fc3d0a792f8a6e7a69f3b85afa1cf7b2560bbd434d7d5c32a580e153c0" +checksum = "a672ec78525bf189cefa7f1b72c55f928b3edbdb967e680ca49748ab20821045" dependencies = [ "async-tungstenite", "async_io_stream", @@ -3086,6 +3171,6 @@ dependencies = [ "log", "pharos", "rustc_version 0.4.0", - "tokio 1.8.2", - "tungstenite 0.13.0", + "tokio 1.17.0", + "tungstenite 0.16.0", ] diff --git a/Cargo.toml b/Cargo.toml index ecf22a6f5..d25103f1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,6 @@ members = [ "rumqttlog", "rumqttmesh", "rumqttd", - "benchmarks" + "benchmarks", + "benchmarks/simplerouter", ] diff --git a/benchmarks/simplerouter/Cargo.toml b/benchmarks/simplerouter/Cargo.toml new file mode 100644 index 000000000..3e2ee24d8 --- /dev/null +++ b/benchmarks/simplerouter/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "simplerouter" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1.1.0" +log = "0.4.14" +pretty_env_logger = "0.4.0" +thiserror = "1.0.30" +tokio = { version = "1.17.0", features = ["net", "sync", "rt-multi-thread", "io-util", "macros"] } diff --git a/benchmarks/simplerouter/src/bin/simplerouter.rs b/benchmarks/simplerouter/src/bin/simplerouter.rs new file mode 100644 index 000000000..cf5527749 --- /dev/null +++ b/benchmarks/simplerouter/src/bin/simplerouter.rs @@ -0,0 +1,11 @@ +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +#[tokio::main] +async fn main() { + pretty_env_logger::init(); + simplerouter::run(simplerouter::Config { + addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1883)), + }) + .await + .unwrap(); +} diff --git a/benchmarks/simplerouter/src/lib.rs b/benchmarks/simplerouter/src/lib.rs new file mode 100644 index 000000000..18e7fa4ee --- /dev/null +++ b/benchmarks/simplerouter/src/lib.rs @@ -0,0 +1,122 @@ +use std::{io, net::SocketAddr}; + +use bytes::BytesMut; +use log::*; +use tokio::net::TcpListener; + +mod network; +mod protocol; +use network::Network; +use protocol::{v4, v5}; + +pub struct Config { + pub addr: SocketAddr, +} + +pub async fn run(config: Config) -> Result<(), Error> { + let listener = TcpListener::bind(config.addr).await?; + info!("router: listening on {}", config.addr); + + loop { + let (stream, addr) = listener.accept().await?; + info!("router: accepted connection from {}", addr); + let (network, _) = match Network::read_connect(stream).await { + Ok(v) => v, + Err(e) => { + error!("router: unable to read connect : {}", e); + continue; + } + }; + info!("connection: sent connack"); + tokio::spawn(publisher_handle(network)); + } +} + +async fn publisher_handle(mut network: Network) { + let mut payload = BytesMut::with_capacity(2); + v4::pingresp::write(&mut payload).unwrap(); + let pingresp_bytes = payload.split().freeze(); + + loop { + let packet = match network.poll().await { + Ok(packet) => packet, + Err(e) => { + error!("connection: unable to read packet: {}", e); + return; + } + }; + match packet { + protocol::Packet::V4(packet) => match packet { + v4::Packet::Disconnect => { + info!("connection: received disconnect, exiting"); + return; + }, + v4::Packet::PingReq => { + if let Err(e) = network.send_data(&pingresp_bytes).await { + error!("unable to send pingresp, exiting : {}", e); + return; + }; + } + v4::Packet::Publish(publish) => { + let pkid = match publish.view_meta() { + Ok(v) => v.2, + Err(e) => { + error!("connection: malformed publish packet : {}", e); + continue; + } + }; + payload.reserve(2); + v4::puback::write(pkid, &mut payload).unwrap(); + if let Err(e) = network.send_data(&payload.split().freeze()).await { + error!("unable to send puback pkid = {}, exiting : {}", pkid, e); + return; + }; + } + p => { + error!("connection: invalid packet {:?}", p); + continue; + } + }, + protocol::Packet::V5(packet) => match packet { + v5::Packet::Disconnect => { + info!("connection: received disconnect, exiting"); + return; + }, + v5::Packet::PingReq => { + if let Err(e) = network.send_data(&pingresp_bytes).await { + error!("unable to send pingresp, exiting : {}", e); + return; + }; + } + v5::Packet::Publish(publish) => { + let pkid = match publish.view_meta() { + Ok(v) => v.2, + Err(e) => { + error!("connection: malformed publish packet : {}", e); + continue; + } + }; + payload.reserve(8); + v5::puback::write(pkid, v5::puback::PubAckReason::Success, None, &mut payload) + .unwrap(); + if let Err(e) = network.send_data(&payload.split().freeze()).await { + error!("unable to send puback pkid = {}, exiting : {}", pkid, e); + return; + }; + } + p => { + error!("connection: invalid packet {:?}", p); + continue; + } + }, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("MQTT : {0}")] + MQTT(#[from] crate::protocol::Error), + #[error("i/O : {0}")] + IO(#[from] io::Error), +} diff --git a/benchmarks/simplerouter/src/network.rs b/benchmarks/simplerouter/src/network.rs new file mode 100644 index 000000000..b61c02fe4 --- /dev/null +++ b/benchmarks/simplerouter/src/network.rs @@ -0,0 +1,102 @@ +use std::io; + +use bytes::{Bytes, BytesMut}; +use log::*; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, +}; + +use crate::{ + protocol::{self, v4, v5, Connect, Packet}, + Error, +}; + +pub(crate) struct Network { + stream: TcpStream, + buf: BytesMut, + protocol_level: u8, +} + +impl Network { + pub(crate) async fn read_connect(stream: TcpStream) -> Result<(Self, Connect), Error> { + let mut network = Self { + stream, + buf: BytesMut::with_capacity(4096), + protocol_level: 0, + }; + debug!("network: reading from stream"); + network.stream.read_buf(&mut network.buf).await?; + let connect_packet = loop { + match protocol::read_first_connect(&mut network.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + network.read_atleast(count).await? + } + res => break res?, + } + }; + debug!("network: read connect"); + match &connect_packet { + Connect::V4(_) => { + network.protocol_level = 4; + let mut payload = BytesMut::with_capacity(10); + v4::connack::write(v4::connack::ConnectReturnCode::Success, false, &mut payload)?; + network.send_data(&payload.split().freeze()).await?; + } + Connect::V5(_) => { + network.protocol_level = 5; + let mut payload = BytesMut::with_capacity(10); + v5::connack::write( + v5::connack::ConnectReturnCode::Success, + false, + None, + &mut payload, + )?; + network.send_data(&payload.split().freeze()).await?; + } + } + debug!("network: sent connack"); + Ok((network, connect_packet)) + } + + async fn read_atleast(&mut self, count: usize) -> io::Result<()> { + let mut len = 0; + while len < count { + len += self.stream.read_buf(&mut self.buf).await?; + } + debug!("network: read {} bytes", len); + + Ok(()) + } + + pub(crate) async fn poll(&mut self) -> Result { + loop { + match self.protocol_level { + 4 => match v4::read_mut(&mut self.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + self.read_atleast(count).await?; + continue; + } + res => return Ok(Packet::V4(res?)), + }, + 5 => match v5::read_mut(&mut self.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + self.read_atleast(count).await?; + continue; + } + res => return Ok(Packet::V5(res?)), + }, + // SAFETY: we don't allow changing protocol_level + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + + pub(crate) async fn send_data(&mut self, data: &Bytes) -> Result<(), Error> { + debug!( + "network: sent {} bytes", + self.stream.write(data.as_ref()).await? + ); + Ok(()) + } +} diff --git a/benchmarks/simplerouter/src/protocol/mod.rs b/benchmarks/simplerouter/src/protocol/mod.rs new file mode 100644 index 000000000..2ecdbd320 --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/mod.rs @@ -0,0 +1,430 @@ +#![allow(dead_code)] +use std::{slice::Iter, str::Utf8Error}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub mod v4; +pub mod v5; + +/// Checks if the filter is valid +/// +/// https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106 +pub fn valid_filter(filter: &str) -> bool { + if filter.is_empty() { + return false; + } + + let hirerarchy = filter.split('/').collect::>(); + if let Some((last, remaining)) = hirerarchy.split_last() { + // # is not allowed in filer except as a last entry + // invalid: sport/tennis#/player + // invalid: sport/tennis/#/ranking + for entry in remaining.iter() { + if entry.contains('#') { + return false; + } + } + + // only single '#" is allowed in last entry + // invalid: sport/tennis# + if last.len() != 1 && last.contains('#') { + return false; + } + } + + true +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum Error { + #[error("Expected connect packet, received = {0:?}")] + NotConnect(PacketType), + #[error("Received an unexpected connect packet")] + UnexpectedConnect, + #[error("Invalid return code received as response for connect = {0}")] + InvalidConnectReturnCode(u8), + #[error("Invalid reason = {0}")] + InvalidReason(u8), + #[error("Invalid protocol used")] + InvalidProtocol, + #[error("Invalid protocol level")] + InvalidProtocolLevel(u8), + #[error("Invalid packet format")] + IncorrectPacketFormat, + #[error("Invalid packet type = {0}")] + InvalidPacketType(u8), + #[error("Packet type unsupported = {0:?}")] + UnsupportedPacket(PacketType), + #[error("Invalid retain forward rule = {0}")] + InvalidRetainForwardRule(u8), + #[error("Invalid QoS level = {0}")] + InvalidQoS(u8), + #[error("Invalid subscribe reason code = {0}")] + InvalidSubscribeReasonCode(u8), + #[error("Packet received has id Zero")] + PacketIdZero, + #[error("Subscription had id Zero")] + SubscriptionIdZero, + #[error("Payload size is incorrect")] + PayloadSizeIncorrect, + #[error("Payload is too long")] + PayloadTooLong, + #[error("Payload size has been exceeded by {0} bytes")] + PayloadSizeLimitExceeded(usize), + #[error("Payload is required")] + PayloadRequired, + #[error("Topic not utf-8 = {0}")] + TopicNotUtf8(#[from] Utf8Error), + #[error("Promised boundary crossed, contains {0} bytes")] + BoundaryCrossed(usize), + #[error("Packet is malformed")] + MalformedPacket, + #[error("Remaining length is malformed")] + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] + InsufficientBytes(usize), + #[error("Property does not exist = {0}")] + InvalidPropertyType(u8), +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + pub byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + pub fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + pub remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +pub fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Returns big endian u16 view from next 2 bytes +pub fn view_u16(stream: &[u8]) -> Result { + let v = match stream.get(0..2) { + Some(v) => (v[0] as u16) << 8 | (v[1] as u16), + None => return Err(Error::MalformedPacket), + }; + + Ok(v) +} + +/// Returns big endian u16 view from next 2 bytes +pub fn view_str(stream: &[u8], end: usize) -> Result<&str, Error> { + let v = match stream.get(0..end) { + Some(v) => v, + None => return Err(Error::BoundaryCrossed(stream.len())), + }; + + let v = std::str::from_utf8(v)?; + Ok(v) +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +pub fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.len() < 1 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +pub fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +pub enum Connect { + V4(v4::connect::Connect), + V5(v5::connect::Connect), +} + +#[derive(Debug)] +pub enum Packet { + V4(v4::Packet), + V5(v5::Packet), +} + +pub(crate) fn read_first_connect(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + match fixed_header.packet_type()? { + PacketType::Connect => {} + p => return Err(Error::NotConnect(p)), + } + let mut packet = packet.freeze(); + + let variable_header_index = fixed_header.fixed_header_len; + packet.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut packet)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + let protocol_level = read_u8(&mut packet)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + match protocol_level { + 4 => Ok(Connect::V4(v4::connect::connect_v4_part(packet)?)), + 5 => Ok(Connect::V5(v5::connect::connect_v5_part(packet)?)), + _ => Err(Error::InvalidProtocolLevel(protocol_level)), + } +} diff --git a/benchmarks/simplerouter/src/protocol/v4.rs b/benchmarks/simplerouter/src/protocol/v4.rs new file mode 100644 index 000000000..e418877c0 --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/v4.rs @@ -0,0 +1,863 @@ +#![allow(dead_code)] + +use super::*; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub(crate) mod connect { + use super::*; + use bytes::Bytes; + + /// Connection packet initiated by the client + #[derive(Debug, Clone, PartialEq)] + pub struct Connect { + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + } + + impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + keep_alive: 10, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut bytes)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + if protocol_level != 4 { + return Err(Error::InvalidProtocolLevel(protocol_level)); + } + + connect_v4_part(bytes) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + buffer.put_u8(0x04); + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } + } + + pub(crate) fn connect_v4_part(mut bytes: Bytes) -> Result { + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + let client_id = read_mqtt_bytes(&mut bytes)?; + let client_id = std::str::from_utf8(&client_id)?.to_owned(); + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + keep_alive, + client_id, + clean_session, + last_will, + login, + }; + + Ok(connect) + } + + /// LastWill that broker forwards on behalf of the client + #[derive(Debug, Clone, PartialEq)] + pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + } + + impl LastWill { + pub fn _new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + } + } + + fn len(&self) -> usize { + let mut len = 0; + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct Login { + username: String, + password: String, + } + + impl Login { + pub fn new>(u: S, p: S) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => { + let username = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&username)?.to_owned() + } + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => { + let password = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&password)?.to_owned() + } + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } + } +} + +pub(crate) mod connack { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum ConnectReturnCode { + Success = 0, + RefusedProtocolVersion, + BadClientId, + ServiceUnavailable, + BadUserNamePassword, + NotAuthorized, + } + + /// Acknowledgement to connect packet + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + } + + impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + }; + + Ok(connack) + } + } + + pub fn write( + code: ConnectReturnCode, + session_present: bool, + buffer: &mut BytesMut, + ) -> Result { + // sesssion present + code + let len = 1 + 1; + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(session_present as u8); + buffer.put_u8(code as u8); + + Ok(1 + count + len) + } + + /// Connection return code type + fn connect_return(num: u8) -> Result { + match num { + 0 => Ok(ConnectReturnCode::Success), + 1 => Ok(ConnectReturnCode::RefusedProtocolVersion), + 2 => Ok(ConnectReturnCode::BadClientId), + 3 => Ok(ConnectReturnCode::ServiceUnavailable), + 4 => Ok(ConnectReturnCode::BadUserNamePassword), + 5 => Ok(ConnectReturnCode::NotAuthorized), + num => Err(Error::InvalidConnectReturnCode(num)), + } + } +} + +pub(crate) mod publish { + use super::*; + use bytes::{BufMut, Bytes, BytesMut}; + + #[derive(Debug, Clone, PartialEq)] + pub struct Publish { + pub fixed_header: FixedHeader, + pub raw: Bytes, + } + + impl Publish { + // pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload: Bytes::from(payload.into()), + // } + // } + + // pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload, + // } + // } + + // pub fn len(&self) -> usize { + // let mut len = 2 + self.topic.len(); + // if self.qos != QoS::AtMostOnce && self.pkid != 0 { + // len += 2; + // } + + // len += self.payload.len(); + // len + // } + + pub fn view_meta(&self) -> Result<(&str, u8, u16, bool, bool), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + let dup = (self.fixed_header.byte1 & 0b1000) != 0; + let retain = (self.fixed_header.byte1 & 0b0001) != 0; + + // FIXME: Remove indexes and use get method + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + + let pkid = match qos { + 0 => 0, + 1 => { + let stream = &stream[topic_len..]; + let pkid = view_u16(stream)?; + pkid + } + v => return Err(Error::InvalidQoS(v)), + }; + + if qos == 1 && pkid == 0 { + return Err(Error::PacketIdZero); + } + + Ok((topic, qos, pkid, dup, retain)) + } + + pub fn view_topic(&self) -> Result<&str, Error> { + // FIXME: Remove indexes + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + Ok(topic) + } + + pub fn take_topic_and_payload(mut self) -> Result<(Bytes, Bytes), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + + let variable_header_index = self.fixed_header.fixed_header_len; + self.raw.advance(variable_header_index); + let topic = read_mqtt_bytes(&mut self.raw)?; + + match qos { + 0 => (), + 1 => self.raw.advance(2), + v => return Err(Error::InvalidQoS(v)), + }; + + let payload = self.raw; + Ok((topic, payload)) + } + + pub fn read(fixed_header: FixedHeader, bytes: Bytes) -> Result { + let publish = Publish { + fixed_header, + raw: bytes, + }; + + Ok(publish) + } + } + + pub struct PublishBytes(pub Bytes); + + impl From for Result { + fn from(raw: PublishBytes) -> Self { + let fixed_header = check(raw.0.iter(), 100 * 1024 * 1024)?; + Ok(Publish { + fixed_header, + raw: raw.0, + }) + } + } + + pub fn write( + topic: &str, + qos: QoS, + pkid: u16, + dup: bool, + retain: bool, + payload: &[u8], + buffer: &mut BytesMut, + ) -> Result { + let mut len = 2 + topic.len(); + if qos != QoS::AtMostOnce { + len += 2; + } + + len += payload.len(); + + let dup = dup as u8; + let qos = qos as u8; + let retain = retain as u8; + + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, topic); + + if qos != 0 { + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + buffer.extend_from_slice(&payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +pub(crate) mod puback { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Acknowledgement to QoS1 publish + #[derive(Debug, Clone, PartialEq)] + pub struct PubAck { + pub pkid: u16, + } + + impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { pkid } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { pkid }); + } + + // No properties len or properties if remaining len > 2 but < 4 + if fixed_header.remaining_len < 4 { + return Ok(PubAck { pkid }); + } + + let puback = PubAck { pkid }; + + Ok(puback) + } + } + + pub fn write(pkid: u16, buffer: &mut BytesMut) -> Result { + let len = 2; // pkid + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + Ok(1 + count + len) + } +} + +pub(crate) mod subscribe { + use super::*; + use bytes::{Buf, Bytes}; + + /// Subscription packet + #[derive(Debug, Clone, PartialEq)] + pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + } + + impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + }; + + let mut filters = Vec::new(); + filters.push(filter); + Subscribe { pkid: 0, filters } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { path, qos }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); // len of pkid + vec![subscribe filter len] + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_bytes(&mut bytes)?; + let path = std::str::from_utf8(&path)?.to_owned(); + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + }); + } + + let subscribe = Subscribe { pkid, filters }; + + Ok(subscribe) + } + } + + pub fn write( + filters: Vec, + pkid: u16, + buffer: &mut BytesMut, + ) -> Result { + let len = 2 + filters.iter().fold(0, |s, t| s + t.len()); // len of pkid + vec![subscribe filter len] + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len_bytes = write_remaining_length(buffer, len)?; + + // write packet id + buffer.put_u16(pkid); + + // write filters + for filter in filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + len) + } + + /// Subscription filter + #[derive(Debug, Clone, PartialEq)] + pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + } + + impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { path, qos } + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } + } +} + +pub(crate) mod suback { + use std::convert::{TryFrom, TryInto}; + + use super::*; + use bytes::{Buf, Bytes}; + + /// Acknowledgement to subscribe + #[derive(Debug, Clone, PartialEq)] + pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + } + + impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { pkid, return_codes } + } + + pub fn len(&self) -> usize { + let len = 2 + self.return_codes.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { pkid, return_codes }; + Ok(suback) + } + } + + pub fn write( + return_codes: Vec, + pkid: u16, + buffer: &mut BytesMut, + ) -> Result { + let len = 2 + return_codes.len(); + buffer.put_u8(0x90); + + let remaining_len_bytes = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + + let p: Vec = return_codes + .iter() + .map(|&code| match code { + SubscribeReasonCode::Success(qos) => qos as u8, + SubscribeReasonCode::Failure => 0x80, + }) + .collect(); + + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + len) + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SubscribeReasonCode { + Success(QoS), + Failure, + } + + impl TryFrom for SubscribeReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::Success(QoS::AtMostOnce), + 1 => SubscribeReasonCode::Success(QoS::AtLeastOnce), + 128 => SubscribeReasonCode::Failure, + v => return Err(Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } + } + + pub fn codes(c: Vec) -> Vec { + c.into_iter() + .map(|v| { + let qos = qos(v).unwrap(); + SubscribeReasonCode::Success(qos) + }) + .collect() + } +} + +pub(crate) mod pingresp { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} + +pub(crate) mod pingreq { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read_mut(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut Bytes, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet { + Connect(connect::Connect), + Publish(publish::Publish), + ConnAck(connack::ConnAck), + PubAck(puback::PubAck), + PingReq, + PingResp, + Subscribe(subscribe::Subscribe), + SubAck(suback::SubAck), + Disconnect, +} diff --git a/benchmarks/simplerouter/src/protocol/v5.rs b/benchmarks/simplerouter/src/protocol/v5.rs new file mode 100644 index 000000000..c63fe816e --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/v5.rs @@ -0,0 +1,1952 @@ +#![allow(dead_code)] + +use std::fmt; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use super::*; + +pub(crate) mod connect { + use super::*; + use bytes::Bytes; + + /// Connection packet initiated by the client + #[derive(Debug, Clone, PartialEq)] + pub struct Connect { + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, + } + + impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + keep_alive: 10, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut bytes)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol_level = read_u8(&mut bytes)?; + if protocol_level != 5 { + return Err(Error::InvalidProtocolLevel(protocol_level)); + } + + connect_v5_part(bytes) + } + } + + pub(crate) fn connect_v5_part(mut bytes: Bytes) -> Result { + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + let properties = ConnectProperties::read(&mut bytes)?; + + // Payload + let client_id = read_mqtt_bytes(&mut bytes)?; + let client_id = std::str::from_utf8(&client_id)?.to_owned(); + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + /// LastWill that broker forwards on behalf of the client + #[derive(Debug, Clone, PartialEq)] + pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + } + + impl LastWill { + pub fn _new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + } + } + + fn len(&self) -> usize { + let mut len = 0; + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + }) + } + }; + + Ok(last_will) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct Login { + username: String, + password: String, + } + + impl Login { + pub fn new>(u: S, p: S) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => { + let username = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&username)?.to_owned() + } + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => { + let password = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&password)?.to_owned() + } + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, + } + + impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_bytes(&mut bytes)?; + let method = std::str::from_utf8(&method)?.to_owned(); + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + } +} + +pub(crate) mod connack { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum ConnectReturnCode { + Success = 0x00, + UnspecifiedError = 0x80, + MalformedPacket = 0x81, + ProtocolError = 0x82, + ImplementationSpecificError = 0x83, + UnsupportedProtocolVersion = 0x84, + ClientIdentifierNotValid = 0x85, + BadUserNamePassword = 0x86, + NotAuthorized = 0x87, + ServerUnavailable = 0x88, + ServerBusy = 0x89, + Banned = 0x8a, + BadAuthenticationMethod = 0x8c, + TopicNameInvalid = 0x90, + PacketTooLarge = 0x95, + QuotaExceeded = 0x97, + PayloadFormatInvalid = 0x99, + RetainNotSupported = 0x9a, + QoSNotSupported = 0x9b, + UseAnotherServer = 0x9c, + ServerMoved = 0x9d, + ConnectionRateExceeded = 0x94, + } + + /// Acknowledgement to connect packet + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, + } + + impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // 1 byte for 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + } + + pub fn write( + code: ConnectReturnCode, + session_present: bool, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + // TODO: maybe we can remove double checking if properties == None ? + + let mut len = 1 // session present + + 1; // code + if let Some(ref properties) = properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // 1 byte for 0 len + len += 1; + } + + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + + buffer.put_u8(session_present as u8); + buffer.put_u8(code as u8); + + if let Some(properties) = properties { + properties.write(buffer)?; + } else { + // 1 byte for 0 len + buffer.put_u8(0); + } + + Ok(1 + count + len) + } + + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, + } + + impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let id = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&reason)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_bytes(&mut bytes)?; + let info = std::str::from_utf8(&info)?.to_owned(); + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reference = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let method = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } + } + + /// Connection return code type + fn connect_return(num: u8) -> Result { + match num { + 0x00 => Ok(ConnectReturnCode::Success), + 0x80 => Ok(ConnectReturnCode::UnspecifiedError), + 0x81 => Ok(ConnectReturnCode::MalformedPacket), + 0x82 => Ok(ConnectReturnCode::ProtocolError), + 0x83 => Ok(ConnectReturnCode::ImplementationSpecificError), + 0x84 => Ok(ConnectReturnCode::UnsupportedProtocolVersion), + 0x85 => Ok(ConnectReturnCode::ClientIdentifierNotValid), + 0x86 => Ok(ConnectReturnCode::BadUserNamePassword), + 0x87 => Ok(ConnectReturnCode::NotAuthorized), + 0x88 => Ok(ConnectReturnCode::ServerUnavailable), + 0x89 => Ok(ConnectReturnCode::ServerBusy), + 0x8a => Ok(ConnectReturnCode::Banned), + 0x8c => Ok(ConnectReturnCode::BadAuthenticationMethod), + 0x90 => Ok(ConnectReturnCode::TopicNameInvalid), + 0x95 => Ok(ConnectReturnCode::PacketTooLarge), + 0x97 => Ok(ConnectReturnCode::QuotaExceeded), + 0x99 => Ok(ConnectReturnCode::PayloadFormatInvalid), + 0x9a => Ok(ConnectReturnCode::RetainNotSupported), + 0x9b => Ok(ConnectReturnCode::QoSNotSupported), + 0x9c => Ok(ConnectReturnCode::UseAnotherServer), + 0x9d => Ok(ConnectReturnCode::ServerMoved), + 0x94 => Ok(ConnectReturnCode::ConnectionRateExceeded), + num => Err(Error::InvalidConnectReturnCode(num)), + } + } +} + +pub(crate) mod publish { + use super::*; + use bytes::{BufMut, Bytes, BytesMut}; + + #[derive(Debug, Clone, PartialEq)] + pub struct Publish { + pub fixed_header: FixedHeader, + pub raw: Bytes, + } + + impl Publish { + // pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload: Bytes::from(payload.into()), + // } + // } + + // pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload, + // } + // } + + // pub fn len(&self) -> usize { + // let mut len = 2 + self.topic.len(); + // if self.qos != QoS::AtMostOnce && self.pkid != 0 { + // len += 2; + // } + + // len += self.payload.len(); + // len + // } + + pub fn view_meta(&self) -> Result<(&str, u8, u16, bool, bool), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + let dup = (self.fixed_header.byte1 & 0b1000) != 0; + let retain = (self.fixed_header.byte1 & 0b0001) != 0; + + // FIXME: Remove indexes and use get method + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + + let pkid = match qos { + 0 => 0, + 1 => { + let stream = &stream[topic_len..]; + let pkid = view_u16(stream)?; + pkid + } + v => return Err(Error::InvalidQoS(v)), + }; + + if qos == 1 && pkid == 0 { + return Err(Error::PacketIdZero); + } + + Ok((topic, qos, pkid, dup, retain)) + } + + pub fn view_topic(&self) -> Result<&str, Error> { + // FIXME: Remove indexes + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + Ok(topic) + } + + pub fn take_topic_and_payload(mut self) -> Result<(Bytes, Bytes), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + + let variable_header_index = self.fixed_header.fixed_header_len; + self.raw.advance(variable_header_index); + let topic = read_mqtt_bytes(&mut self.raw)?; + + match qos { + 0 => (), + 1 => self.raw.advance(2), + v => return Err(Error::InvalidQoS(v)), + }; + + let payload = self.raw; + Ok((topic, payload)) + } + + pub fn read(fixed_header: FixedHeader, bytes: Bytes) -> Result { + let publish = Publish { + fixed_header, + raw: bytes, + }; + + Ok(publish) + } + } + + pub struct PublishBytes(pub Bytes); + + impl From for Result { + fn from(raw: PublishBytes) -> Self { + let fixed_header = check(raw.0.iter(), 100 * 1024 * 1024)?; + Ok(Publish { + fixed_header, + raw: raw.0, + }) + } + } + + pub fn write( + topic: &str, + qos: QoS, + pkid: u16, + dup: bool, + retain: bool, + payload: &[u8], + buffer: &mut BytesMut, + ) -> Result { + let mut len = 2 + topic.len(); + if qos != QoS::AtMostOnce { + len += 2; + } + + len += payload.len(); + + let dup = dup as u8; + let qos = qos as u8; + let retain = retain as u8; + + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, topic); + + if qos != 0 { + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + buffer.extend_from_slice(&payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +pub(crate) mod puback { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Acknowledgement to QoS1 publish + #[derive(Debug, Clone, PartialEq)] + pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, + } + + impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + } + + pub fn write( + pkid: u16, + reason: PubAckReason, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + buffer.put_u8(0x40); + + match &properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + let len = 2 + 1 + properties_len_len + properties_len; + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + buffer.put_u8(reason as u8); + properties.write(buffer)?; + + Ok(len + count + 1) + } + None => { + // Unlike other packets, property length can be ignored if there are + // no properties in acks + // + // TODO: maybe we should set len = 2 for PubAckReason == Success + let len = 2 + 1; + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + buffer.put_u8(reason as u8); + + Ok(len + count + 1) + } + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + } + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + } + + impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + /// Connection return code type + fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) + } +} + +pub(crate) mod subscribe { + use super::*; + use bytes::{Buf, Bytes}; + + /// Subscription packet + #[derive(Clone, PartialEq)] + pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, + } + + impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + let mut filters = Vec::new(); + filters.push(filter); + Subscribe { + pkid: 0, + filters, + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_bytes(&mut bytes)?; + let path = std::str::from_utf8(&path)?.to_owned(); + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = if nolocal == 0 { false } else { true }; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = if preserve_retain == 0 { false } else { true }; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + } + + pub fn write( + filters: Vec, + pkid: u16, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let mut len = 2 + filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + let remaining_len = len; + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(pkid); + + match &properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } + + /// Subscription filter + #[derive(Clone, PartialEq)] + pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, + } + + impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, + } + + impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, + } + + impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } + } + + impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } + } +} + +pub(crate) mod suback { + use std::convert::{TryFrom, TryInto}; + + use super::*; + use bytes::{Buf, Bytes}; + + /// Acknowledgement to subscribe + #[derive(Debug, Clone, PartialEq)] + pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, + } + + impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + } + + pub fn write( + return_codes: Vec, + pkid: u16, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + buffer.put_u8(0x90); + + let mut len = 2 + return_codes.len(); + + match &properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + let remaining_len = len; + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(pkid); + + match &properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } + + #[derive(Debug, Clone, PartialEq)] + pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + } + + impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, + } + + impl TryFrom for SubscribeReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } + } + + pub fn codes(c: Vec) -> Vec { + c.into_iter() + .map(|v| match qos(v).unwrap() { + QoS::AtMostOnce => SubscribeReasonCode::QoS0, + QoS::AtLeastOnce => SubscribeReasonCode::QoS1, + }) + .collect() + } +} + +pub(crate) mod pingresp { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read_mut(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut Bytes, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet { + Connect(connect::Connect), + Publish(publish::Publish), + ConnAck(connack::ConnAck), + PubAck(puback::PubAck), + PingReq, + PingResp, + Subscribe(subscribe::Subscribe), + SubAck(suback::SubAck), + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +} From f2394675355627167d94d577c473b61ba754f9a2 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 6 Mar 2022 19:53:26 +0530 Subject: [PATCH 15/38] Make rustls optional --- rumqttc/Cargo.toml | 6 ++++-- rumqttc/src/eventloop.rs | 9 ++++++--- rumqttc/src/lib.rs | 9 +++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 54730e835..2483d62be 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -14,13 +14,15 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -websocket = ["async-tungstenite", "ws_stream_tungstenite"] +default = ["use-rustls"] +websocket = ["async-tungstenite", "ws_stream_tungstenite", "use-rustls"] +use-rustls = ["tokio-rustls"] [dependencies] tokio = { version = "1.0", features = ["rt", "macros", "io-util", "net", "time"] } bytes = "1.0" webpki = "0.22.0" -tokio-rustls = "0.23.2" +tokio-rustls = { version = "0.23.2", optional = true } rustls-pemfile = "0.3.0" async-tungstenite = { version = "0.16.1", default-features = false, features = ["tokio-rustls-native-certs"], optional = true } ws_stream_tungstenite = { version = "0.7.0", default-features = false, features = ["tokio_io"], optional = true } diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index a98893171..1b1a79f8a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,6 +1,7 @@ -use crate::{framed::Network, Transport}; -use crate::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::{MqttOptions, Outgoing}; +use crate::framed::Network; +#[cfg(feature = "use-rustls")] +use crate::tls; +use crate::{Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport}; use crate::mqttbytes; use crate::mqttbytes::v4::*; @@ -31,6 +32,7 @@ pub enum ConnectionError { Timeout(#[from] Elapsed), #[error("Packet parsing error: {0}")] Mqtt4Bytes(mqttbytes::Error), + #[cfg(feature = "use-rustls")] #[error("Network: {0}")] Network(#[from] tls::Error), #[error("I/O: {0}")] @@ -274,6 +276,7 @@ async fn network_connect(options: &MqttOptions) -> Result { let socket = tls::tls_connect(options, &tls_config).await?; Network::new(socket, options.max_incoming_packet_size) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 364f651fd..91ae8ae96 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -100,6 +100,7 @@ extern crate log; use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] use std::sync::Arc; use std::time::Duration; @@ -108,6 +109,7 @@ mod eventloop; mod framed; pub mod mqttbytes; mod state; +#[cfg(feature = "use-rustls")] mod tls; pub use async_channel::{SendError, Sender, TrySendError}; @@ -116,7 +118,9 @@ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] pub use tls::Error; +#[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -194,6 +198,7 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, + #[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, @@ -218,6 +223,7 @@ impl Transport { } /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] pub fn tls( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -232,6 +238,7 @@ impl Transport { Self::tls_with_config(config) } + #[cfg(feature = "use-rustls")] pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { Self::Tls(tls_config) } @@ -273,6 +280,7 @@ impl Transport { } #[derive(Clone)] +#[cfg(feature = "use-rustls")] pub enum TlsConfiguration { Simple { /// connection method @@ -286,6 +294,7 @@ pub enum TlsConfiguration { Rustls(Arc), } +#[cfg(feature = "use-rustls")] impl From for TlsConfiguration { fn from(config: ClientConfig) -> Self { TlsConfiguration::Rustls(Arc::new(config)) From d10fd408bebe72344a311687e7ddb540e0dc154b Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 6 Mar 2022 20:24:04 +0530 Subject: [PATCH 16/38] Enable WS w/o TLS; fix breaking examples --- rumqttc/Cargo.toml | 2 +- rumqttc/examples/tls.rs | 12 +++++++++--- rumqttc/examples/tls2.rs | 10 ++++++++-- rumqttc/src/eventloop.rs | 6 ++++-- rumqttc/src/lib.rs | 14 +++++++------- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 2483d62be..af57dd083 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -15,7 +15,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = ["use-rustls"] -websocket = ["async-tungstenite", "ws_stream_tungstenite", "use-rustls"] +websocket = ["async-tungstenite", "ws_stream_tungstenite"] use-rustls = ["tokio-rustls"] [dependencies] diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index 2bb1b2272..0e25860e3 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -1,11 +1,12 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. - -use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; -use rustls::ClientConfig; use std::error::Error; +#[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { + use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; + use rustls::ClientConfig; + pretty_env_logger::init(); color_backtrace::install(); @@ -43,3 +44,8 @@ async fn main() -> Result<(), Box> { } } } + +#[cfg(not(feature = "use-rustls"))] +fn main() -> Result<(), Box> { + panic!("Enable feature 'use-rustls'"); +} diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index c6df58e85..496a806a0 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,10 +1,11 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. - -use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; use std::error::Error; +#[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { + use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; + pretty_env_logger::init(); color_backtrace::install(); @@ -43,3 +44,8 @@ async fn main() -> Result<(), Box> { Ok(()) } + +#[cfg(not(feature = "use-rustls"))] +fn main() -> Result<(), Box> { + panic!("Enable feature 'use-rustls'"); +} diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 1b1a79f8a..711c167cb 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -7,7 +7,9 @@ use crate::mqttbytes; use crate::mqttbytes::v4::*; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] -use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use async_tungstenite::tokio::connect_async; +#[cfg(all(feature = "use-rustls", feature = "websocket"))] +use async_tungstenite::tokio::connect_async_with_tls_connector; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; @@ -302,7 +304,7 @@ async fn network_connect(options: &MqttOptions) -> Result { let request = http::Request::builder() .method(http::Method::GET) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 91ae8ae96..328ce11c1 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -205,8 +205,8 @@ pub enum Transport { #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] Wss(TlsConfiguration), } @@ -256,8 +256,8 @@ impl Transport { } /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -272,8 +272,8 @@ impl Transport { Self::wss_with_config(config) } - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { Self::Wss(tls_config) } @@ -724,7 +724,7 @@ mod test { } #[test] - #[cfg(feature = "websocket")] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); From 279805a33798cac7fa6e397998371bde5511ec62 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Mon, 7 Mar 2022 12:27:16 +0530 Subject: [PATCH 17/38] cmacos test Signed-off-by: Abhik Jain --- rumqttc/tests/reliability.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 3a129b6b8..db41af29d 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -360,7 +360,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { if ack == 1 { let elapsed = start.elapsed().as_millis() as i64; let deviation_millis: i64 = (5000 - elapsed).abs(); - assert!(deviation_millis < 10); + assert!(deviation_millis < 100); break; } } From 2e012b5cf203ea8eef05dae0d539ab4857351169 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 9 Mar 2022 18:46:21 +0530 Subject: [PATCH 18/38] Make rustls_pemfile optional & rm webpki --- Cargo.lock | 1 - rumqttc/Cargo.toml | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bdefe09ff..78b32f695 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1685,7 +1685,6 @@ dependencies = [ "tokio", "tokio-rustls", "url", - "webpki", "ws_stream_tungstenite", ] diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index af57dd083..e9951d144 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -16,14 +16,13 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = ["use-rustls"] websocket = ["async-tungstenite", "ws_stream_tungstenite"] -use-rustls = ["tokio-rustls"] +use-rustls = ["tokio-rustls", "rustls-pemfile"] [dependencies] tokio = { version = "1.0", features = ["rt", "macros", "io-util", "net", "time"] } bytes = "1.0" -webpki = "0.22.0" tokio-rustls = { version = "0.23.2", optional = true } -rustls-pemfile = "0.3.0" +rustls-pemfile = { version = "0.3.0", optional = true } async-tungstenite = { version = "0.16.1", default-features = false, features = ["tokio-rustls-native-certs"], optional = true } ws_stream_tungstenite = { version = "0.7.0", default-features = false, features = ["tokio_io"], optional = true } pollster = "0.2" From 9eaa6470a9ef1259945432a0cf5283b7a7e107e2 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 16 Mar 2022 17:28:03 +0530 Subject: [PATCH 19/38] Fix use of "wont fix" label by bot --- .github/stale.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/stale.yml b/.github/stale.yml index 37158a003..beed37682 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -5,6 +5,8 @@ exemptLabels: - "help wanted" - "bug" +staleLabel: "stale" + markComment: > This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you From 76916b08520f6eced5b0039a5e467cea8f215296 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 16 Mar 2022 17:36:44 +0530 Subject: [PATCH 20/38] Don't close PR/issue or comment --- .github/stale.yml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/stale.yml b/.github/stale.yml index beed37682..09f7b78a9 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,13 +1,8 @@ daysUntilStale: 20 - +staleLabel: "stale" +daysUntilClose: false +markComment: false exemptLabels: - "in-pipeline" - "help wanted" - "bug" - -staleLabel: "stale" - -markComment: > - This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. From a149b77f373f2a94425e05b8c80e40c2a878611a Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 18 Mar 2022 00:54:09 +0530 Subject: [PATCH 21/38] rumqttc: v5: separate pub/sub client option Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 274 +++++++++++++++++++++++++++ rumqttc/src/v5/client/mod.rs | 102 ++++++++++ rumqttc/src/v5/client/publisher.rs | 98 ++++++++++ rumqttc/src/v5/client/subscriber.rs | 121 ++++++++++++ rumqttc/src/v5/client/syncclient.rs | 135 +++++++++++++ rumqttc/src/v5/eventloop.rs | 8 +- rumqttc/src/v5/mod.rs | 10 + rumqttc/src/v5/state.rs | 21 +- 8 files changed, 761 insertions(+), 8 deletions(-) create mode 100644 rumqttc/src/v5/client/asyncclient.rs create mode 100644 rumqttc/src/v5/client/mod.rs create mode 100644 rumqttc/src/v5/client/publisher.rs create mode 100644 rumqttc/src/v5/client/subscriber.rs create mode 100644 rumqttc/src/v5/client/syncclient.rs diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs new file mode 100644 index 000000000..a00b295ef --- /dev/null +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -0,0 +1,274 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{ + client::{get_ack_req, Publisher, Subscriber}, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + ClientError, EventLoop, MqttOptions, QoS, Request, +}; + +/// `AsyncClient` to communicate with MQTT `Eventloop` +/// This is cloneable and can be used to asynchronously Publish, Subscribe. +#[derive(Clone, Debug)] +pub struct AsyncClient { + request_buf: Arc>>, + sub_events_buf: Arc>>, + sub_events_buf_cache: VecDeque, + request_buf_capacity: usize, + request_tx: Sender<()>, + pub(crate) cancel_tx: Sender<()>, +} + +impl AsyncClient { + /// Create a new `AsyncClient` + pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { + let mut eventloop = EventLoop::new(options, cap); + let request_buf = eventloop.request_buf().clone(); + let sub_events_buf = eventloop.sub_events_buf().clone(); + let sub_events_buf_cache = VecDeque::with_capacity(cap); + let request_tx = eventloop.handle(); + let cancel_tx = eventloop.cancel_handle(); + + let client = AsyncClient { + request_buf, + request_buf_capacity: cap, + sub_events_buf, + sub_events_buf_cache, + request_tx, + cancel_tx, + }; + + (client, eventloop) + } + + /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for + /// creating a test instance. + pub fn from_senders( + request_buf: Arc>>, + sub_events_buf: Arc>>, + request_tx: Sender<()>, + cancel_tx: Sender<()>, + cap: usize, + ) -> AsyncClient { + AsyncClient { + request_buf, + request_buf_capacity: cap, + sub_events_buf, + sub_events_buf_cache: VecDeque::with_capacity(cap), + request_tx, + cancel_tx, + } + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.send_async_and_notify(Request::Publish(publish)) + .await?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.try_send_and_notify(Request::Publish(publish))?; + Ok(pkid) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + self.send_async_and_notify(ack).await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + self.try_send_and_notify(ack)?; + } + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: Bytes, + ) -> Result<(), ClientError> + where + S: Into, + { + let mut publish = Publish::from_bytes(topic, qos, payload); + publish.retain = retain; + self.send_async_and_notify(Request::Publish(publish)).await + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) + .await + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) + } + + /// Sends a MQTT disconnect to the eventloop + pub async fn disconnect(&self) -> Result<(), ClientError> { + self.send_async_and_notify(Request::Disconnect).await + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + self.try_send_and_notify(Request::Disconnect) + } + + /// Stops the eventloop right away + pub async fn cancel(&self) -> Result<(), ClientError> { + self.cancel_tx + .send_async(()) + .await + .map_err(ClientError::Cancel) + } + + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + } + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(SendError(_)) = self.request_tx.send(()) { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } + + pub fn next_publish(&mut self) -> Option { + if let Some(publish) = self.sub_events_buf_cache.pop_front() { + return Some(publish); + } + + std::mem::swap( + &mut self.sub_events_buf_cache, + &mut *self.sub_events_buf.lock().unwrap(), + ); + self.sub_events_buf_cache.pop_front() + } + + pub async fn split( + self, + publish_topic: impl Into, + publish_qos: QoS, + ) -> Result<(Publisher, Subscriber), ClientError> { + let publisher = Publisher { + request_buf: self.request_buf.clone(), + request_buf_capacity: self.request_buf_capacity, + request_tx: self.request_tx.clone(), + cancel_tx: self.cancel_tx.clone(), + publish_topic: publish_topic.into(), + publish_qos, + }; + let subscriber = Subscriber { + request_buf: self.request_buf, + sub_events_buf: self.sub_events_buf, + sub_events_buf_cache: self.sub_events_buf_cache, + request_buf_capacity: self.request_buf_capacity, + request_tx: self.request_tx, + }; + Ok((publisher, subscriber)) + } +} diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs new file mode 100644 index 000000000..3156126df --- /dev/null +++ b/rumqttc/src/v5/client/mod.rs @@ -0,0 +1,102 @@ +//! This module offers a high level synchronous and asynchronous abstraction to +//! async eventloop. +use crate::v5::{packet::*, ConnectionError, Event, EventLoop, Request}; + +use flume::SendError; +use std::mem; +use tokio::runtime::{self, Runtime}; + +mod asyncclient; +pub use asyncclient::AsyncClient; +mod publisher; +pub use publisher::Publisher; +mod subscriber; +pub use subscriber::Subscriber; +mod syncclient; +pub use syncclient::Client; + +/// Client Error +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("Failed to send cancel request to eventloop")] + Cancel(SendError<()>), + #[error("Failed to send mqtt request to eventloop, the evenloop has been closed")] + EventloopClosed, + #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] + RequestsFull, + #[error("Serialization error")] + Mqtt5(Error), +} + +fn get_ack_req(qos: QoS, pkid: u16) -> Option { + let ack = match qos { + QoS::AtMostOnce => return None, + QoS::AtLeastOnce => Request::PubAck(PubAck::new(pkid)), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(pkid)), + }; + Some(ack) +} + + +/// MQTT connection. Maintains all the necessary state +pub struct Connection { + pub eventloop: EventLoop, + runtime: Option, +} + +impl Connection { + fn new(eventloop: EventLoop, runtime: Runtime) -> Connection { + Connection { + eventloop, + runtime: Some(runtime), + } + } + + /// Returns an iterator over this connection. Iterating over this is all that's + /// necessary to make connection progress and maintain a robust connection. + /// Just continuing to loop will reconnect + /// **NOTE** Don't block this while iterating + #[must_use = "Connection should be iterated over a loop to make progress"] + pub fn iter(&mut self) -> Iter { + let runtime = self.runtime.take().unwrap(); + Iter { + connection: self, + runtime, + } + } +} + +/// Iterator which polls the eventloop for connection progress +pub struct Iter<'a> { + connection: &'a mut Connection, + runtime: runtime::Runtime, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + let f = self.connection.eventloop.poll(); + match self.runtime.block_on(f) { + Ok(v) => Some(Ok(v)), + // closing of request channel should stop the iterator + Err(ConnectionError::RequestsDone) => { + trace!("Done with requests"); + None + } + Err(ConnectionError::Cancel) => { + trace!("Cancellation request received"); + None + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a> Drop for Iter<'a> { + fn drop(&mut self) { + // TODO: Don't create new runtime in drop + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + self.connection.runtime = Some(mem::replace(&mut self.runtime, runtime)); + } +} diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs new file mode 100644 index 000000000..ae47769f6 --- /dev/null +++ b/rumqttc/src/v5/client/publisher.rs @@ -0,0 +1,98 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{packet::Publish, ClientError, QoS, Request}; + +pub struct Publisher { + pub(crate) request_buf: Arc>>, + pub(crate) request_buf_capacity: usize, + pub(crate) request_tx: Sender<()>, + pub(crate) cancel_tx: Sender<()>, + pub(crate) publish_topic: String, + pub(crate) publish_qos: QoS, +} + +impl Publisher { + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + retain: bool, + payload: impl Into>, + ) -> Result { + let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.send_async_and_notify(Request::Publish(publish)) + .await?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + retain: bool, + payload: impl Into>, + ) -> Result { + let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.try_send_and_notify(Request::Publish(publish))?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result<(), ClientError> { + let mut publish = Publish::from_bytes(&self.publish_topic, self.publish_qos, payload); + publish.retain = retain; + self.send_async_and_notify(Request::Publish(publish)).await + } + + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + } + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + /// Sends a MQTT disconnect to the eventloop + pub async fn disconnect(&self) -> Result<(), ClientError> { + self.send_async_and_notify(Request::Disconnect).await + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + self.try_send_and_notify(Request::Disconnect) + } + + /// Stops the eventloop right away + pub async fn cancel(&self) -> Result<(), ClientError> { + self.cancel_tx + .send_async(()) + .await + .map_err(ClientError::Cancel) + } + + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } +} diff --git a/rumqttc/src/v5/client/subscriber.rs b/rumqttc/src/v5/client/subscriber.rs new file mode 100644 index 000000000..6a7889f18 --- /dev/null +++ b/rumqttc/src/v5/client/subscriber.rs @@ -0,0 +1,121 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{ + client::get_ack_req, ClientError, Publish, QoS, Request, Subscribe, SubscribeFilter, + Unsubscribe, +}; + +#[derive(Debug, Clone)] +pub struct Subscriber { + pub(crate) request_buf: Arc>>, + pub(crate) sub_events_buf: Arc>>, + pub(crate) sub_events_buf_cache: VecDeque, + pub(crate) request_buf_capacity: usize, + pub(crate) request_tx: Sender<()>, +} + +impl Subscriber { + pub fn next_publish(&mut self) -> Option { + if let Some(publish) = self.sub_events_buf_cache.pop_front() { + return Some(publish); + } + + std::mem::swap( + &mut self.sub_events_buf_cache, + &mut *self.sub_events_buf.lock().unwrap(), + ); + self.sub_events_buf_cache.pop_front() + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(qos, pkid) { + self.send_async_and_notify(ack).await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(qos, pkid) { + self.try_send_and_notify(ack)?; + } + Ok(()) + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.send_async_and_notify(Request::Subscribe(subscribe)) + .await + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.try_send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) + .await + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) + } + + async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + } + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + let mut request_buf = self.request_buf.lock().unwrap(); + if request_buf.len() == self.request_buf_capacity { + return Err(ClientError::RequestsFull); + } + request_buf.push_back(request); + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } +} diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs new file mode 100644 index 000000000..bed105012 --- /dev/null +++ b/rumqttc/src/v5/client/syncclient.rs @@ -0,0 +1,135 @@ +use tokio::runtime; + +use crate::v5::{ + client::get_ack_req, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + AsyncClient, ClientError, Connection, MqttOptions, QoS, Request, +}; + +/// `Client` to communicate with MQTT eventloop `Connection`. +/// +/// Client is cloneable and can be used to synchronously Publish, Subscribe. +/// Asynchronous channel handle can also be extracted if necessary +#[derive(Clone)] +pub struct Client { + client: AsyncClient, +} + +impl Client { + /// Create a new `Client` + pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { + let (client, eventloop) = AsyncClient::new(options, cap); + let client = Client { client }; + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let connection = Connection::new(eventloop, runtime); + (client, connection) + } + + /// Sends a MQTT Publish to the eventloop + pub fn publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = publish.pkid; + self.client.send_and_notify(Request::Publish(publish))?; + Ok(pkid) + } + + pub fn try_publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + self.client.try_publish(topic, qos, retain, payload) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + self.client.send_and_notify(ack)?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + self.client.try_ack(publish) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { + let subscribe = Subscribe::new(topic.into(), qos); + self.client.send_and_notify(Request::Subscribe(subscribe)) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result<(), ClientError> { + self.client.try_subscribe(topic, qos) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + let subscribe = Subscribe::new_many(topics); + self.client.send_and_notify(Request::Subscribe(subscribe)) + } + + pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + where + T: IntoIterator, + { + self.client.try_subscribe_many(topics) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + let unsubscribe = Unsubscribe::new(topic.into()); + self.client + .send_and_notify(Request::Unsubscribe(unsubscribe)) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + self.client.try_unsubscribe(topic) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn disconnect(&mut self) -> Result<(), ClientError> { + self.client.send_and_notify(Request::Disconnect) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + self.client.try_disconnect() + } + + /// Stops the eventloop right away + pub fn cancel(&mut self) -> Result<(), ClientError> { + self.client.cancel_tx.send(()).map_err(ClientError::Cancel) + } +} diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 659a713e7..7c1843100 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -93,7 +93,7 @@ impl EventLoop { EventLoop { options, - state: MqttState::new(max_inflight, manual_acks), + state: MqttState::new(max_inflight, manual_acks, cap), request_buf, request_buf_cache: VecDeque::with_capacity(cap), requests_tx, @@ -111,10 +111,14 @@ impl EventLoop { self.requests_tx.clone() } - pub fn buf(&self) -> &Arc>> { + pub fn request_buf(&self) -> &Arc>> { &self.request_buf } + pub fn sub_events_buf(&self) -> &Arc>> { + &self.state.sub_events_buf + } + /// Handle for cancelling the eventloop. /// /// Can be useful in cases when connection should be halted immediately diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 83e21ef01..d8d2db4c1 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -233,6 +233,7 @@ pub struct MqttOptions { /// If set to `true` MQTT acknowledgements are not sent automatically. /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. manual_acks: bool, + override_slow_subs: bool, } impl MqttOptions { @@ -260,6 +261,7 @@ impl MqttOptions { last_will: None, conn_timeout: 5, manual_acks: false, + override_slow_subs: false, } } @@ -408,6 +410,14 @@ impl MqttOptions { pub fn manual_acks(&self) -> bool { self.manual_acks } + + pub fn set_override_slow_subs(&mut self, val: bool) { + self.override_slow_subs = val; + } + + pub fn override_slow_subs(&self) -> bool { + self.override_slow_subs + } } #[cfg(feature = "url")] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 2bc2b470a..0cec123b3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -2,7 +2,11 @@ use super::{packet::*, Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use std::collections::VecDeque; -use std::{io, mem, time::Instant}; +use std::{ + io, mem, + sync::{Arc, Mutex}, + time::Instant, +}; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -75,13 +79,14 @@ pub struct MqttState { pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, + pub(crate) sub_events_buf: Arc>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, cap: usize) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -99,6 +104,7 @@ impl MqttState { events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, + sub_events_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), } } @@ -195,13 +201,12 @@ impl MqttState { let qos = publish.qos; match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => {}, QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); self.outgoing_puback(puback)? } - Ok(()) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -210,9 +215,13 @@ impl MqttState { let pubrec = PubRec::new(pkid); self.outgoing_pubrec(pubrec)?; } - Ok(()) } } + + // TODO: maybe limit the capacity of `self.sub_events_buf` + self.sub_events_buf.lock().unwrap().push_back(publish.clone()); + + Ok(()) } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { @@ -507,7 +516,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(100, false) + MqttState::new(100, false, 100) } #[test] From 1b3cdffd44cad95e606a9eb627886cf12f380d33 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 19 Mar 2022 09:37:30 +0530 Subject: [PATCH 22/38] Rename tls Error --- rumqttc/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 328ce11c1..9118fc68e 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -119,7 +119,7 @@ pub use mqttbytes::v4::*; pub use mqttbytes::*; pub use state::{MqttState, StateError}; #[cfg(feature = "use-rustls")] -pub use tls::Error; +pub use tls::Error as TlsError; #[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; From 57441f92adb8d06bac872243ecd0a2dc2dfdc1e9 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 19 Mar 2022 09:44:49 +0530 Subject: [PATCH 23/38] Restructure parts of error handling --- rumqttc/src/eventloop.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 711c167cb..2d2e9d0b5 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -35,10 +35,13 @@ pub enum ConnectionError { #[error("Packet parsing error: {0}")] Mqtt4Bytes(mqttbytes::Error), #[cfg(feature = "use-rustls")] - #[error("Network: {0}")] - Network(#[from] tls::Error), + #[error("Tls Error: {0}")] + Tls(#[from] tls::Error), #[error("I/O: {0}")] Io(#[from] io::Error), + #[cfg(feature = "websocket")] + #[error("Websocket Connect: {0}")] + WsConnect(#[from] http::Error), #[error("Stream done")] StreamDone, #[error("Requests done")] @@ -295,8 +298,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result Date: Tue, 22 Mar 2022 19:23:25 +0530 Subject: [PATCH 24/38] rumqttc: v5: fix pkid Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 50 +++++++++++-- rumqttc/src/v5/client/publisher.rs | 37 ++++++++-- rumqttc/src/v5/state.rs | 102 ++++++++++++++++++--------- 3 files changed, 145 insertions(+), 44 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index a00b295ef..d811cb2fc 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -1,6 +1,9 @@ use std::{ collections::VecDeque, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, Mutex, + }, }; use bytes::Bytes; @@ -18,6 +21,8 @@ use crate::v5::{ pub struct AsyncClient { request_buf: Arc>>, sub_events_buf: Arc>>, + pkid_counter: Arc, + max_inflight: u16, sub_events_buf_cache: VecDeque, request_buf_capacity: usize, request_tx: Sender<()>, @@ -33,11 +38,15 @@ impl AsyncClient { let sub_events_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); let cancel_tx = eventloop.cancel_handle(); + let max_inflight = eventloop.state.max_inflight; + let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { request_buf, request_buf_capacity: cap, sub_events_buf, + pkid_counter, + max_inflight, sub_events_buf_cache, request_tx, cancel_tx, @@ -51,6 +60,8 @@ impl AsyncClient { pub fn from_senders( request_buf: Arc>>, sub_events_buf: Arc>>, + pkid_counter: Arc, + max_inflight: u16, request_tx: Sender<()>, cancel_tx: Sender<()>, cap: usize, @@ -58,6 +69,8 @@ impl AsyncClient { AsyncClient { request_buf, request_buf_capacity: cap, + pkid_counter, + max_inflight, sub_events_buf, sub_events_buf_cache: VecDeque::with_capacity(cap), request_tx, @@ -79,7 +92,8 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.send_async_and_notify(Request::Publish(publish)) .await?; Ok(pkid) @@ -99,7 +113,8 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.try_send_and_notify(Request::Publish(publish))?; Ok(pkid) } @@ -127,13 +142,16 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - self.send_async_and_notify(Request::Publish(publish)).await + let pkid = self.increment_pkid(); + publish.pkid = pkid; + self.send_async_and_notify(Request::Publish(publish)).await?; + Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop @@ -257,6 +275,8 @@ impl AsyncClient { let publisher = Publisher { request_buf: self.request_buf.clone(), request_buf_capacity: self.request_buf_capacity, + pkid_counter: self.pkid_counter, + max_inflight: self.max_inflight, request_tx: self.request_tx.clone(), cancel_tx: self.cancel_tx.clone(), publish_topic: publish_topic.into(), @@ -271,4 +291,24 @@ impl AsyncClient { }; Ok((publisher, subscriber)) } + + fn increment_pkid(&self) -> u16 { + let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); + loop { + let new_pkid = if cur_pkid > self.max_inflight { + 1 + } else { + cur_pkid + 1 + }; + match self.pkid_counter.compare_exchange( + cur_pkid, + new_pkid, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_prev_pkid) => break new_pkid, + Err(actual_pkid) => cur_pkid = actual_pkid, + } + } + } } diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs index ae47769f6..e4de4af8b 100644 --- a/rumqttc/src/v5/client/publisher.rs +++ b/rumqttc/src/v5/client/publisher.rs @@ -1,6 +1,6 @@ use std::{ collections::VecDeque, - sync::{Arc, Mutex}, + sync::{Arc, atomic::{AtomicU16, Ordering}, Mutex}, }; use bytes::Bytes; @@ -11,6 +11,8 @@ use crate::v5::{packet::Publish, ClientError, QoS, Request}; pub struct Publisher { pub(crate) request_buf: Arc>>, pub(crate) request_buf_capacity: usize, + pub(crate) pkid_counter: Arc, + pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, pub(crate) cancel_tx: Sender<()>, pub(crate) publish_topic: String, @@ -26,7 +28,8 @@ impl Publisher { ) -> Result { let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.send_async_and_notify(Request::Publish(publish)) .await?; Ok(pkid) @@ -40,16 +43,20 @@ impl Publisher { ) -> Result { let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); publish.retain = retain; - let pkid = publish.pkid; + let pkid = self.increment_pkid(); + publish.pkid = pkid; self.try_send_and_notify(Request::Publish(publish))?; Ok(pkid) } /// Sends a MQTT Publish to the eventloop - pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result<(), ClientError> { + pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result { let mut publish = Publish::from_bytes(&self.publish_topic, self.publish_qos, payload); + let pkid = self.increment_pkid(); + publish.pkid = pkid; publish.retain = retain; - self.send_async_and_notify(Request::Publish(publish)).await + self.send_async_and_notify(Request::Publish(publish)).await?; + Ok(pkid) } async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { @@ -95,4 +102,24 @@ impl Publisher { } Ok(()) } + + fn increment_pkid(&self) -> u16 { + let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); + loop { + let new_pkid = if cur_pkid > self.max_inflight { + 1 + } else { + cur_pkid + 1 + }; + match self.pkid_counter.compare_exchange( + cur_pkid, + new_pkid, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_prev_pkid) => break new_pkid, + Err(actual_pkid) => cur_pkid = actual_pkid, + } + } + } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 0cec123b3..c990f92d0 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -2,6 +2,7 @@ use super::{packet::*, Event, Incoming, Outgoing, Request}; use bytes::BytesMut; use std::collections::VecDeque; +use std::sync::atomic::{AtomicU16, Ordering}; use std::{ io, mem, sync::{Arc, Mutex}, @@ -59,8 +60,6 @@ pub struct MqttState { last_incoming: Instant, /// Last outgoing packet time last_outgoing: Instant, - /// Packet id of the last outgoing packet - pub(crate) last_pkid: u16, /// Number of outgoing inflight publishes pub(crate) inflight: u16, /// Maximum number of allowed inflight @@ -80,6 +79,7 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, pub(crate) sub_events_buf: Arc>>, + pkid_counter: Arc, } impl MqttState { @@ -92,7 +92,6 @@ impl MqttState { collision_ping_count: 0, last_incoming: Instant::now(), last_outgoing: Instant::now(), - last_pkid: 0, inflight: 0, max_inflight, // index 0 is wasted as 0 is not a valid packet id @@ -105,6 +104,38 @@ impl MqttState { write: BytesMut::with_capacity(10 * 1024), manual_acks, sub_events_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + pkid_counter: Arc::new(AtomicU16::new(0)), + } + } + + #[inline] + pub(crate) fn pkid_counter(&self) -> &Arc { + &self.pkid_counter + } + + #[cfg(test)] + #[inline] + fn cur_pkid(&self) -> u16 { + self.pkid_counter.load(Ordering::SeqCst) + } + + pub(crate) fn increment_pkid(&self) -> u16 { + let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); + loop { + let new_pkid = if cur_pkid > self.max_inflight { + 1 + } else { + cur_pkid + 1 + }; + match self.pkid_counter.compare_exchange( + cur_pkid, + new_pkid, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_prev_pkid) => break new_pkid, + Err(actual_pkid) => cur_pkid = actual_pkid, + } } } @@ -201,7 +232,7 @@ impl MqttState { let qos = publish.qos; match qos { - QoS::AtMostOnce => {}, + QoS::AtMostOnce => {} QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); @@ -219,7 +250,10 @@ impl MqttState { } // TODO: maybe limit the capacity of `self.sub_events_buf` - self.sub_events_buf.lock().unwrap().push_back(publish.clone()); + self.sub_events_buf + .lock() + .unwrap() + .push_back(publish.clone()); Ok(()) } @@ -312,7 +346,7 @@ impl MqttState { fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { - publish.pkid = self.next_pkid(); + publish.pkid = self.increment_pkid(); } let pkid = publish.pkid; @@ -409,7 +443,7 @@ impl MqttState { } fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { - let pkid = self.next_pkid(); + let pkid = self.increment_pkid(); subscription.pkid = pkid; debug!( @@ -424,7 +458,7 @@ impl MqttState { } fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { - let pkid = self.next_pkid(); + let pkid = self.increment_pkid(); unsub.pkid = pkid; debug!( @@ -461,7 +495,7 @@ impl MqttState { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets 0 => { - pubrel.pkid = self.next_pkid(); + pubrel.pkid = self.increment_pkid(); pubrel } _ => pubrel, @@ -471,24 +505,24 @@ impl MqttState { Ok(pubrel) } - /// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation - /// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. - /// - fn next_pkid(&mut self) -> u16 { - let next_pkid = self.last_pkid + 1; - - // When next packet id is at the edge of inflight queue, - // set await flag. This instructs eventloop to stop - // processing requests until all the inflight publishes - // are acked - if next_pkid == self.max_inflight { - self.last_pkid = 0; - return next_pkid; - } - - self.last_pkid = next_pkid; - next_pkid - } + ///// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation + ///// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. + ///// + //fn next_pkid(&mut self) -> u16 { + // let next_pkid = self.last_pkid + 1; + + // // When next packet id is at the edge of inflight queue, + // // set await flag. This instructs eventloop to stop + // // processing requests until all the inflight publishes + // // are acked + // if next_pkid == self.max_inflight { + // self.last_pkid = 0; + // return next_pkid; + // } + + // self.last_pkid = next_pkid; + // next_pkid + //} } #[cfg(test)] @@ -521,10 +555,10 @@ mod test { #[test] fn next_pkid_increments_as_expected() { - let mut mqtt = build_mqttstate(); + let mqtt = build_mqttstate(); for i in 1..=100 { - let pkid = mqtt.next_pkid(); + let pkid = mqtt.increment_pkid(); // loops between 0-99. % 100 == 0 implies border let expected = i % 100; @@ -545,7 +579,7 @@ mod test { // QoS 0 publish shouldn't be saved in queue mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.last_pkid, 0); + assert_eq!(mqtt.cur_pkid(), 0); assert_eq!(mqtt.inflight, 0); // QoS1 Publish @@ -553,12 +587,12 @@ mod test { // Packet id should be set and publish should be saved in queue mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.last_pkid, 1); + assert_eq!(mqtt.cur_pkid(), 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.last_pkid, 2); + assert_eq!(mqtt.cur_pkid(), 2); assert_eq!(mqtt.inflight, 2); // QoS1 Publish @@ -566,12 +600,12 @@ mod test { // Packet id should be set and publish should be saved in queue mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.last_pkid, 3); + assert_eq!(mqtt.cur_pkid(), 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.last_pkid, 4); + assert_eq!(mqtt.cur_pkid(), 4); assert_eq!(mqtt.inflight, 4); } From b516430361e946cfb19fd6bb55d422560b30729e Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:03:01 +0530 Subject: [PATCH 25/38] rumqttc: v5: remove cancel channel Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 16 +------------ rumqttc/src/v5/client/publisher.rs | 9 ------- rumqttc/src/v5/client/syncclient.rs | 5 ---- rumqttc/src/v5/eventloop.rs | 35 +--------------------------- 4 files changed, 2 insertions(+), 63 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index d811cb2fc..478b189fd 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -26,18 +26,16 @@ pub struct AsyncClient { sub_events_buf_cache: VecDeque, request_buf_capacity: usize, request_tx: Sender<()>, - pub(crate) cancel_tx: Sender<()>, } impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { - let mut eventloop = EventLoop::new(options, cap); + let eventloop = EventLoop::new(options, cap); let request_buf = eventloop.request_buf().clone(); let sub_events_buf = eventloop.sub_events_buf().clone(); let sub_events_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); - let cancel_tx = eventloop.cancel_handle(); let max_inflight = eventloop.state.max_inflight; let pkid_counter = eventloop.state.pkid_counter().clone(); @@ -49,7 +47,6 @@ impl AsyncClient { max_inflight, sub_events_buf_cache, request_tx, - cancel_tx, }; (client, eventloop) @@ -63,7 +60,6 @@ impl AsyncClient { pkid_counter: Arc, max_inflight: u16, request_tx: Sender<()>, - cancel_tx: Sender<()>, cap: usize, ) -> AsyncClient { AsyncClient { @@ -74,7 +70,6 @@ impl AsyncClient { sub_events_buf, sub_events_buf_cache: VecDeque::with_capacity(cap), request_tx, - cancel_tx, } } @@ -209,14 +204,6 @@ impl AsyncClient { self.try_send_and_notify(Request::Disconnect) } - /// Stops the eventloop right away - pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx - .send_async(()) - .await - .map_err(ClientError::Cancel) - } - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { let mut request_buf = self.request_buf.lock().unwrap(); @@ -278,7 +265,6 @@ impl AsyncClient { pkid_counter: self.pkid_counter, max_inflight: self.max_inflight, request_tx: self.request_tx.clone(), - cancel_tx: self.cancel_tx.clone(), publish_topic: publish_topic.into(), publish_qos, }; diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs index e4de4af8b..8cf924254 100644 --- a/rumqttc/src/v5/client/publisher.rs +++ b/rumqttc/src/v5/client/publisher.rs @@ -14,7 +14,6 @@ pub struct Publisher { pub(crate) pkid_counter: Arc, pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, - pub(crate) cancel_tx: Sender<()>, pub(crate) publish_topic: String, pub(crate) publish_qos: QoS, } @@ -83,14 +82,6 @@ impl Publisher { self.try_send_and_notify(Request::Disconnect) } - /// Stops the eventloop right away - pub async fn cancel(&self) -> Result<(), ClientError> { - self.cancel_tx - .send_async(()) - .await - .map_err(ClientError::Cancel) - } - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.request_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index bed105012..749b06627 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -127,9 +127,4 @@ impl Client { pub fn try_disconnect(&mut self) -> Result<(), ClientError> { self.client.try_disconnect() } - - /// Stops the eventloop right away - pub fn cancel(&mut self) -> Result<(), ClientError> { - self.client.cancel_tx.send(()).map_err(ClientError::Cancel) - } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 7c1843100..0c03fe3bf 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -64,10 +64,6 @@ pub struct EventLoop { pub(crate) network: Option, /// Keep alive time pub(crate) keepalive_timeout: Option>>, - /// Handle to read cancellation requests - pub(crate) cancel_rx: Receiver<()>, - /// Handle to send cancellation requests (and drops) - pub(crate) cancel_tx: Sender<()>, } /// Events which can be yielded by the event loop @@ -83,7 +79,6 @@ impl EventLoop { /// When connection encounters critical errors (like auth failure), user has a choice to /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { - let (cancel_tx, cancel_rx) = bounded(5); let (requests_tx, requests_rx) = bounded(1); let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); @@ -101,8 +96,6 @@ impl EventLoop { pending, network: None, keepalive_timeout: None, - cancel_rx, - cancel_tx, } } @@ -119,14 +112,6 @@ impl EventLoop { &self.state.sub_events_buf } - /// Handle for cancelling the eventloop. - /// - /// Can be useful in cases when connection should be halted immediately - /// between half-open connection detections or (re)connection timeouts - pub(crate) fn cancel_handle(&mut self) -> Sender<()> { - self.cancel_tx.clone() - } - fn clean(&mut self) { self.network = None; self.keepalive_timeout = None; @@ -140,7 +125,7 @@ impl EventLoop { /// **NOTE** Don't block this while iterating pub async fn poll(&mut self) -> Result { if self.network.is_none() { - let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?; + let (network, connack) = connect(&self.options).await?; self.network = Some(network); if self.keepalive_timeout.is_none() { @@ -243,29 +228,11 @@ impl EventLoop { network.flush(&mut self.state.write).await?; return Ok(self.state.events.pop_front().unwrap()) } - // cancellation requests to stop the polling - _ = self.cancel_rx.recv_async() => { - return Err(ConnectionError::Cancel) - } } } } } -async fn connect_or_cancel( - options: &MqttOptions, - cancel_rx: &Receiver<()>, -) -> Result<(Network, Incoming), ConnectionError> { - // select here prevents cancel request from being blocked until connection request is - // resolved. Returns with an error if connections fail continuously - select! { - o = connect(options) => o, - _ = cancel_rx.recv_async() => { - Err(ConnectionError::Cancel) - } - } -} - /// This stream internally processes requests from the request stream provided to the eventloop /// while also consuming byte stream from the network and yielding mqtt packets as the output of /// the stream. From d2de93ea3994917d26e9217de72866b7c8aa81ce Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:13:08 +0530 Subject: [PATCH 26/38] rumqttc: v5: renaming buffers Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 62 ++++++++++++++-------------- rumqttc/src/v5/client/publisher.rs | 12 +++--- rumqttc/src/v5/client/subscriber.rs | 18 ++++---- rumqttc/src/v5/eventloop.rs | 32 +++++++------- rumqttc/src/v5/state.rs | 8 ++-- 5 files changed, 66 insertions(+), 66 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 478b189fd..84a75b8f0 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -19,12 +19,12 @@ use crate::v5::{ /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Clone, Debug)] pub struct AsyncClient { - request_buf: Arc>>, - sub_events_buf: Arc>>, + incoming_buf: Arc>>, + outgoing_buf: Arc>>, pkid_counter: Arc, max_inflight: u16, - sub_events_buf_cache: VecDeque, - request_buf_capacity: usize, + incoming_buf_cache: VecDeque, + incoming_buf_capacity: usize, request_tx: Sender<()>, } @@ -33,19 +33,19 @@ impl AsyncClient { pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); let request_buf = eventloop.request_buf().clone(); - let sub_events_buf = eventloop.sub_events_buf().clone(); - let sub_events_buf_cache = VecDeque::with_capacity(cap); + let incoming_buf = eventloop.state.incoming_buf.clone(); + let incoming_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { - request_buf, - request_buf_capacity: cap, - sub_events_buf, + incoming_buf: request_buf, + incoming_buf_capacity: cap, + outgoing_buf: incoming_buf, pkid_counter, max_inflight, - sub_events_buf_cache, + incoming_buf_cache, request_tx, }; @@ -56,19 +56,19 @@ impl AsyncClient { /// creating a test instance. pub fn from_senders( request_buf: Arc>>, - sub_events_buf: Arc>>, + incoming_buf: Arc>>, pkid_counter: Arc, max_inflight: u16, request_tx: Sender<()>, cap: usize, ) -> AsyncClient { AsyncClient { - request_buf, - request_buf_capacity: cap, + incoming_buf: request_buf, + incoming_buf_capacity: cap, pkid_counter, max_inflight, - sub_events_buf, - sub_events_buf_cache: VecDeque::with_capacity(cap), + outgoing_buf: incoming_buf, + incoming_buf_cache: VecDeque::with_capacity(cap), request_tx, } } @@ -206,8 +206,8 @@ impl AsyncClient { async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -219,8 +219,8 @@ impl AsyncClient { } pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -231,8 +231,8 @@ impl AsyncClient { } fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -243,15 +243,15 @@ impl AsyncClient { } pub fn next_publish(&mut self) -> Option { - if let Some(publish) = self.sub_events_buf_cache.pop_front() { + if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } std::mem::swap( - &mut self.sub_events_buf_cache, - &mut *self.sub_events_buf.lock().unwrap(), + &mut self.incoming_buf_cache, + &mut *self.outgoing_buf.lock().unwrap(), ); - self.sub_events_buf_cache.pop_front() + self.incoming_buf_cache.pop_front() } pub async fn split( @@ -260,8 +260,8 @@ impl AsyncClient { publish_qos: QoS, ) -> Result<(Publisher, Subscriber), ClientError> { let publisher = Publisher { - request_buf: self.request_buf.clone(), - request_buf_capacity: self.request_buf_capacity, + incoming_buf: self.incoming_buf.clone(), + incoming_buf_capacity: self.incoming_buf_capacity, pkid_counter: self.pkid_counter, max_inflight: self.max_inflight, request_tx: self.request_tx.clone(), @@ -269,10 +269,10 @@ impl AsyncClient { publish_qos, }; let subscriber = Subscriber { - request_buf: self.request_buf, - sub_events_buf: self.sub_events_buf, - sub_events_buf_cache: self.sub_events_buf_cache, - request_buf_capacity: self.request_buf_capacity, + outgoing_buf: self.incoming_buf, + incoming_buf: self.outgoing_buf, + incoming_buf_cache: self.incoming_buf_cache, + request_buf_capacity: self.incoming_buf_capacity, request_tx: self.request_tx, }; Ok((publisher, subscriber)) diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs index 8cf924254..3b5c07f2d 100644 --- a/rumqttc/src/v5/client/publisher.rs +++ b/rumqttc/src/v5/client/publisher.rs @@ -9,8 +9,8 @@ use flume::{SendError, Sender, TrySendError}; use crate::v5::{packet::Publish, ClientError, QoS, Request}; pub struct Publisher { - pub(crate) request_buf: Arc>>, - pub(crate) request_buf_capacity: usize, + pub(crate) incoming_buf: Arc>>, + pub(crate) incoming_buf_capacity: usize, pub(crate) pkid_counter: Arc, pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, @@ -60,8 +60,8 @@ impl Publisher { async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -83,8 +83,8 @@ impl Publisher { } fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { + let mut request_buf = self.incoming_buf.lock().unwrap(); + if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); diff --git a/rumqttc/src/v5/client/subscriber.rs b/rumqttc/src/v5/client/subscriber.rs index 6a7889f18..d1982bddd 100644 --- a/rumqttc/src/v5/client/subscriber.rs +++ b/rumqttc/src/v5/client/subscriber.rs @@ -12,24 +12,24 @@ use crate::v5::{ #[derive(Debug, Clone)] pub struct Subscriber { - pub(crate) request_buf: Arc>>, - pub(crate) sub_events_buf: Arc>>, - pub(crate) sub_events_buf_cache: VecDeque, + pub(crate) outgoing_buf: Arc>>, + pub(crate) incoming_buf: Arc>>, + pub(crate) incoming_buf_cache: VecDeque, pub(crate) request_buf_capacity: usize, pub(crate) request_tx: Sender<()>, } impl Subscriber { pub fn next_publish(&mut self) -> Option { - if let Some(publish) = self.sub_events_buf_cache.pop_front() { + if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } std::mem::swap( - &mut self.sub_events_buf_cache, - &mut *self.sub_events_buf.lock().unwrap(), + &mut self.incoming_buf_cache, + &mut *self.incoming_buf.lock().unwrap(), ); - self.sub_events_buf_cache.pop_front() + self.incoming_buf_cache.pop_front() } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. @@ -95,7 +95,7 @@ impl Subscriber { async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.request_buf.lock().unwrap(); + let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { return Err(ClientError::RequestsFull); } @@ -108,7 +108,7 @@ impl Subscriber { } fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.request_buf.lock().unwrap(); + let mut request_buf = self.outgoing_buf.lock().unwrap(); if request_buf.len() == self.request_buf_capacity { return Err(ClientError::RequestsFull); } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 0c03fe3bf..d8b8ad4cf 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -52,12 +52,12 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, - request_buf: Arc>>, - request_buf_cache: VecDeque, + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, /// Request stream - pub requests_rx: Receiver<()>, + pub incoming_rx: Receiver<()>, /// Requests handle to send requests - pub requests_tx: Sender<()>, + pub incoming_tx: Sender<()>, /// Pending packets from last session pub pending: IntoIter, /// Network connection to the broker @@ -79,7 +79,7 @@ impl EventLoop { /// When connection encounters critical errors (like auth failure), user has a choice to /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { - let (requests_tx, requests_rx) = bounded(1); + let (incoming_tx, incoming_rx) = bounded(1); let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); let pending = pending.into_iter(); @@ -89,10 +89,10 @@ impl EventLoop { EventLoop { options, state: MqttState::new(max_inflight, manual_acks, cap), - request_buf, - request_buf_cache: VecDeque::with_capacity(cap), - requests_tx, - requests_rx, + incoming_buf: request_buf, + incoming_buf_cache: VecDeque::with_capacity(cap), + incoming_tx, + incoming_rx, pending, network: None, keepalive_timeout: None, @@ -101,15 +101,15 @@ impl EventLoop { /// Returns a handle to communicate with this eventloop pub fn handle(&self) -> Sender<()> { - self.requests_tx.clone() + self.incoming_tx.clone() } pub fn request_buf(&self) -> &Arc>> { - &self.request_buf + &self.incoming_buf } pub fn sub_events_buf(&self) -> &Arc>> { - &self.state.sub_events_buf + &self.state.incoming_buf } fn clean(&mut self) { @@ -194,14 +194,14 @@ impl EventLoop { // After collision with pkid 1 -> [1b ,2, x, 4, 5]. // 1a is saved to state and event loop is set to collision mode stopping new // outgoing requests (along with 1b). - o = self.requests_rx.recv_async(), if !inflight_full && !pending && !collision => match o { + o = self.incoming_rx.recv_async(), if !inflight_full && !pending && !collision => match o { Ok(_request_notif) => { // swapping to avoid blocking the mutex - std::mem::swap(&mut self.request_buf_cache,&mut *self.request_buf.lock().unwrap()); - if self.request_buf_cache.is_empty() { + std::mem::swap(&mut self.incoming_buf_cache,&mut *self.incoming_buf.lock().unwrap()); + if self.incoming_buf_cache.is_empty() { continue; } - for request in self.request_buf_cache.drain(..) { + for request in self.incoming_buf_cache.drain(..) { self.state.handle_outgoing_packet(request)?; } network.flush(&mut self.state.write).await?; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index c990f92d0..8ad0e5375 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -78,7 +78,7 @@ pub struct MqttState { pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - pub(crate) sub_events_buf: Arc>>, + pub(crate) incoming_buf: Arc>>, pkid_counter: Arc, } @@ -103,7 +103,7 @@ impl MqttState { events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, - sub_events_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), pkid_counter: Arc::new(AtomicU16::new(0)), } } @@ -249,8 +249,8 @@ impl MqttState { } } - // TODO: maybe limit the capacity of `self.sub_events_buf` - self.sub_events_buf + // TODO: maybe limit the capacity of `self.incoming_buf` + self.incoming_buf .lock() .unwrap() .push_back(publish.clone()); From d7a897ad1541240b8c10a3206317d3c8bbf70ff3 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:13:49 +0530 Subject: [PATCH 27/38] rumqttc: v5: asyncclient: remove `AsyncClient::from_senders(..)` Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 84a75b8f0..9e72ab45c 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -52,27 +52,6 @@ impl AsyncClient { (client, eventloop) } - /// Create a new `AsyncClient` from a pair of async channel `Sender`s. This is mostly useful for - /// creating a test instance. - pub fn from_senders( - request_buf: Arc>>, - incoming_buf: Arc>>, - pkid_counter: Arc, - max_inflight: u16, - request_tx: Sender<()>, - cap: usize, - ) -> AsyncClient { - AsyncClient { - incoming_buf: request_buf, - incoming_buf_capacity: cap, - pkid_counter, - max_inflight, - outgoing_buf: incoming_buf, - incoming_buf_cache: VecDeque::with_capacity(cap), - request_tx, - } - } - /// Sends a MQTT Publish to the eventloop pub async fn publish( &self, From 923e2adfdb094d70194af43c81fc6b1dd93408a4 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 19:15:35 +0530 Subject: [PATCH 28/38] rumqttc: v5: asyncclient: internal functions rename Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 32 ++++++++++++++-------------- rumqttc/src/v5/client/syncclient.rs | 12 +++++------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 9e72ab45c..f7e1f3783 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -68,7 +68,7 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.send_async_and_notify(Request::Publish(publish)) + self.push_and_async_notify(Request::Publish(publish)) .await?; Ok(pkid) } @@ -89,14 +89,14 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.try_send_and_notify(Request::Publish(publish))?; + self.push_and_try_notify(Request::Publish(publish))?; Ok(pkid) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.send_async_and_notify(ack).await?; + self.push_and_async_notify(ack).await?; } Ok(()) } @@ -104,7 +104,7 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.try_send_and_notify(ack)?; + self.push_and_try_notify(ack)?; } Ok(()) } @@ -124,21 +124,21 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.send_async_and_notify(Request::Publish(publish)).await?; + self.push_and_async_notify(Request::Publish(publish)).await?; Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.send_async_and_notify(Request::Subscribe(subscribe)) + self.push_and_async_notify(Request::Subscribe(subscribe)) .await } /// Sends a MQTT Subscribe to the eventloop pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.try_send_and_notify(Request::Subscribe(subscribe)) + self.push_and_try_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe for multiple topics to the eventloop @@ -147,7 +147,7 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.send_async_and_notify(Request::Subscribe(subscribe)) + self.push_and_async_notify(Request::Subscribe(subscribe)) .await } @@ -157,33 +157,33 @@ impl AsyncClient { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.try_send_and_notify(Request::Subscribe(subscribe)) + self.push_and_try_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) + self.push_and_async_notify(Request::Unsubscribe(unsubscribe)) .await } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); - self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) + self.push_and_try_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT disconnect to the eventloop pub async fn disconnect(&self) -> Result<(), ClientError> { - self.send_async_and_notify(Request::Disconnect).await + self.push_and_async_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.try_send_and_notify(Request::Disconnect) + self.push_and_try_notify(Request::Disconnect) } - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { + async fn push_and_async_notify(&self, request: Request) -> Result<(), ClientError> { { let mut request_buf = self.incoming_buf.lock().unwrap(); if request_buf.len() == self.incoming_buf_capacity { @@ -197,7 +197,7 @@ impl AsyncClient { Ok(()) } - pub(crate) fn send_and_notify(&self, request: Request) -> Result<(), ClientError> { + pub(crate) fn push_and_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.incoming_buf.lock().unwrap(); if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); @@ -209,7 +209,7 @@ impl AsyncClient { Ok(()) } - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { + fn push_and_try_notify(&self, request: Request) -> Result<(), ClientError> { let mut request_buf = self.incoming_buf.lock().unwrap(); if request_buf.len() == self.incoming_buf_capacity { return Err(ClientError::RequestsFull); diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index 749b06627..e517db9ff 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -44,7 +44,7 @@ impl Client { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = publish.pkid; - self.client.send_and_notify(Request::Publish(publish))?; + self.client.push_and_notify(Request::Publish(publish))?; Ok(pkid) } @@ -65,7 +65,7 @@ impl Client { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.client.send_and_notify(ack)?; + self.client.push_and_notify(ack)?; } Ok(()) } @@ -78,7 +78,7 @@ impl Client { /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { let subscribe = Subscribe::new(topic.into(), qos); - self.client.send_and_notify(Request::Subscribe(subscribe)) + self.client.push_and_notify(Request::Subscribe(subscribe)) } /// Sends a MQTT Subscribe to the eventloop @@ -96,7 +96,7 @@ impl Client { T: IntoIterator, { let subscribe = Subscribe::new_many(topics); - self.client.send_and_notify(Request::Subscribe(subscribe)) + self.client.push_and_notify(Request::Subscribe(subscribe)) } pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> @@ -110,7 +110,7 @@ impl Client { pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { let unsubscribe = Unsubscribe::new(topic.into()); self.client - .send_and_notify(Request::Unsubscribe(unsubscribe)) + .push_and_notify(Request::Unsubscribe(unsubscribe)) } /// Sends a MQTT Unsubscribe to the eventloop @@ -120,7 +120,7 @@ impl Client { /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - self.client.send_and_notify(Request::Disconnect) + self.client.push_and_notify(Request::Disconnect) } /// Sends a MQTT disconnect to the eventloop From acfda4cd27e9010af3659a959815cead1d42d9b7 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:14:06 +0530 Subject: [PATCH 29/38] rumqttc: merge master (fixed) Signed-off-by: Abhik Jain --- rumqttc/examples/websocket.rs | 2 +- rumqttc/src/v4/mod.rs | 127 +++++++++++++++++++++++++++++++--- rumqttc/src/v5/eventloop.rs | 8 ++- rumqttc/src/v5/mod.rs | 33 +++++---- 4 files changed, 140 insertions(+), 30 deletions(-) diff --git a/rumqttc/examples/websocket.rs b/rumqttc/examples/websocket.rs index 712427b3c..b5d1e39a4 100644 --- a/rumqttc/examples/websocket.rs +++ b/rumqttc/examples/websocket.rs @@ -1,5 +1,5 @@ #[cfg(feature = "websocket")] -use rumqttc::{self, AsyncClient, MqttOptions, QoS, Transport}; +use rumqttc::v4::{AsyncClient, MqttOptions, QoS, Transport}; #[cfg(feature = "websocket")] use std::{error::Error, time::Duration}; #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs index 709b327bf..0c39b6478 100644 --- a/rumqttc/src/v4/mod.rs +++ b/rumqttc/src/v4/mod.rs @@ -1,4 +1,103 @@ +//! A pure rust MQTT client which strives to be robust, efficient and easy to use. +//! This library is backed by an async (tokio) eventloop which handles all the +//! robustness and and efficiency parts of MQTT but naturally fits into both sync +//! and async worlds as we'll see +//! +//! Let's jump into examples right away +//! +//! A simple synchronous publish and subscribe +//! ---------------------------- +//! +//! ```no_run +//! use rumqttc::v4::{MqttOptions, Client, QoS}; +//! use std::time::Duration; +//! use std::thread; +//! +//! let mut mqttoptions = MqttOptions::new("rumqtt-sync", "test.mosquitto.org", 1883); +//! mqttoptions.set_keep_alive(Duration::from_secs(5)); +//! +//! let (mut client, mut connection) = Client::new(mqttoptions, 10); +//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).unwrap(); +//! thread::spawn(move || for i in 0..10 { +//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).unwrap(); +//! thread::sleep(Duration::from_millis(100)); +//! }); +//! +//! // Iterate to poll the eventloop for connection progress +//! for (i, notification) in connection.iter().enumerate() { +//! println!("Notification = {:?}", notification); +//! } +//! ``` +//! +//! A simple asynchronous publish and subscribe +//! ------------------------------ +//! +//! ```no_run +//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; +//! use tokio::{task, time}; +//! use std::time::Duration; +//! use std::error::Error; +//! +//! # #[tokio::main(worker_threads = 1)] +//! # async fn main() { +//! let mut mqttoptions = MqttOptions::new("rumqtt-async", "test.mosquitto.org", 1883); +//! mqttoptions.set_keep_alive(Duration::from_secs(5)); +//! +//! let (mut client, mut eventloop) = AsyncClient::new(mqttoptions, 10); +//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).await.unwrap(); +//! +//! task::spawn(async move { +//! for i in 0..10 { +//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).await.unwrap(); +//! time::sleep(Duration::from_millis(100)).await; +//! } +//! }); +//! +//! loop { +//! let notification = eventloop.poll().await.unwrap(); +//! println!("Received = {:?}", notification); +//! } +//! # } +//! ``` +//! +//! Quick overview of features +//! - Eventloop orchestrates outgoing/incoming packets concurrently and hadles the state +//! - Pings the broker when necessary and detects client side half open connections as well +//! - Throttling of outgoing packets (todo) +//! - Queue size based flow control on outgoing packets +//! - Automatic reconnections by just continuing the `eventloop.poll()/connection.iter()` loop` +//! - Natural backpressure to client APIs during bad network +//! - Immediate cancellation with `client.cancel()` +//! +//! In short, everything necessary to maintain a robust connection +//! +//! Since the eventloop is externally polled (with `iter()/poll()` in a loop) +//! out side the library and `Eventloop` is accessible, users can +//! - Distribute incoming messages based on topics +//! - Stop it when required +//! - Access internal state for use cases like graceful shutdown or to modify options before reconnection +//! +//! ## Important notes +//! +//! - Looping on `connection.iter()`/`eventloop.poll()` is necessary to run the +//! event loop and make progress. It yields incoming and outgoing activity +//! notifications which allows customization as you see fit. +//! +//! - Blocking inside the `connection.iter()`/`eventloop.poll()` loop will block +//! connection progress. +//! +//! ## FAQ +//! **Connecting to a broker using raw ip doesn't work** +//! +//! You cannot create a TLS connection to a bare IP address with a self-signed +//! certificate. This is a [limitation of rustls](https://github.com/ctz/rustls/issues/184). +//! One workaround, which only works under *nix/BSD-like systems, is to add an +//! entry to wherever your DNS resolver looks (e.g. `/etc/hosts`) for the bare IP +//! address and use that name in your code. +#![cfg_attr(docsrs, feature(doc_cfg))] + use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] use std::sync::Arc; use std::time::Duration; @@ -6,15 +105,18 @@ mod client; mod eventloop; mod framed; mod state; +#[cfg(feature = "use-rustls")] mod tls; -pub use crate::mqttbytes::v4::*; -pub use crate::mqttbytes::*; pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use crate::mqttbytes::v4::*; +pub use crate::mqttbytes::*; pub use state::{MqttState, StateError}; -pub use tls::Error; +#[cfg(feature = "use-rustls")] +pub use tls::Error as TlsError; +#[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -92,14 +194,15 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, + #[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] Wss(TlsConfiguration), } @@ -116,6 +219,7 @@ impl Transport { } /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] pub fn tls( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -130,6 +234,7 @@ impl Transport { Self::tls_with_config(config) } + #[cfg(feature = "use-rustls")] pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { Self::Tls(tls_config) } @@ -147,8 +252,8 @@ impl Transport { } /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -163,14 +268,15 @@ impl Transport { Self::wss_with_config(config) } - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { Self::Wss(tls_config) } } #[derive(Clone)] +#[cfg(feature = "use-rustls")] pub enum TlsConfiguration { Simple { /// connection method @@ -184,6 +290,7 @@ pub enum TlsConfiguration { Rustls(Arc), } +#[cfg(feature = "use-rustls")] impl From for TlsConfiguration { fn from(config: ClientConfig) -> Self { TlsConfiguration::Rustls(Arc::new(config)) @@ -613,7 +720,7 @@ mod test { } #[test] - #[cfg(feature = "websocket")] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index d8b8ad4cf..a28fd1e23 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,7 +1,9 @@ use crate::v5::{ - framed::Network, packet::*, tls, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, + framed::Network, packet::*, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport, }; +#[cfg(feature = "use-rustls")] +use crate::v5::tls; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; @@ -34,6 +36,7 @@ pub enum ConnectionError { Timeout(#[from] Elapsed), #[error("Packet parsing error: {0}")] Mqtt5Bytes(Error), + #[cfg(feature = "use-rustls")] #[error("Network: {0}")] Network(#[from] tls::Error), #[error("I/O: {0}")] @@ -269,8 +272,9 @@ async fn network_connect(options: &MqttOptions) -> Result { - let socket = tls::tls_connect(&options, &tls_config).await?; + let socket = tls::tls_connect(options, &tls_config).await?; Network::new(socket, options.max_incoming_packet_size) } #[cfg(unix)] diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d8d2db4c1..943a14f9b 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,4 +1,5 @@ use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] use std::sync::Arc; use std::time::Duration; @@ -7,6 +8,7 @@ mod eventloop; mod framed; mod packet; mod state; +#[cfg(feature = "use-rustls")] mod tls; pub use client::{AsyncClient, Client, ClientError, Connection}; @@ -14,7 +16,9 @@ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use flume::{SendError, Sender, TrySendError}; pub use packet::*; pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] pub use tls::Error; +#[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -92,14 +96,15 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, +#[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] Wss(TlsConfiguration), } @@ -116,6 +121,7 @@ impl Transport { } /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] pub fn tls( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -130,6 +136,7 @@ impl Transport { Self::tls_with_config(config) } + #[cfg(feature = "use-rustls")] pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { Self::Tls(tls_config) } @@ -147,8 +154,8 @@ impl Transport { } /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -163,14 +170,15 @@ impl Transport { Self::wss_with_config(config) } - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { Self::Wss(tls_config) } } #[derive(Clone)] +#[cfg(feature = "use-rustls")] pub enum TlsConfiguration { Simple { /// connection method @@ -184,6 +192,7 @@ pub enum TlsConfiguration { Rustls(Arc), } +#[cfg(feature = "use-rustls")] impl From for TlsConfiguration { fn from(config: ClientConfig) -> Self { TlsConfiguration::Rustls(Arc::new(config)) @@ -233,7 +242,6 @@ pub struct MqttOptions { /// If set to `true` MQTT acknowledgements are not sent automatically. /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. manual_acks: bool, - override_slow_subs: bool, } impl MqttOptions { @@ -261,7 +269,6 @@ impl MqttOptions { last_will: None, conn_timeout: 5, manual_acks: false, - override_slow_subs: false, } } @@ -410,14 +417,6 @@ impl MqttOptions { pub fn manual_acks(&self) -> bool { self.manual_acks } - - pub fn set_override_slow_subs(&mut self, val: bool) { - self.override_slow_subs = val; - } - - pub fn override_slow_subs(&self) -> bool { - self.override_slow_subs - } } #[cfg(feature = "url")] @@ -623,7 +622,7 @@ mod test { } #[test] - #[cfg(feature = "websocket")] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); From 2d56f3e912284853182948cf81d54a4ae3f9772e Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:26:14 +0530 Subject: [PATCH 30/38] rumqttc: client: rollback pub-sub seperation Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 26 +----- rumqttc/src/v5/client/mod.rs | 4 - rumqttc/src/v5/client/publisher.rs | 116 ------------------------- rumqttc/src/v5/client/subscriber.rs | 121 --------------------------- 4 files changed, 1 insertion(+), 266 deletions(-) delete mode 100644 rumqttc/src/v5/client/publisher.rs delete mode 100644 rumqttc/src/v5/client/subscriber.rs diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index f7e1f3783..5483da737 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -10,7 +10,7 @@ use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; use crate::v5::{ - client::{get_ack_req, Publisher, Subscriber}, + client::get_ack_req, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, ClientError, EventLoop, MqttOptions, QoS, Request, }; @@ -233,30 +233,6 @@ impl AsyncClient { self.incoming_buf_cache.pop_front() } - pub async fn split( - self, - publish_topic: impl Into, - publish_qos: QoS, - ) -> Result<(Publisher, Subscriber), ClientError> { - let publisher = Publisher { - incoming_buf: self.incoming_buf.clone(), - incoming_buf_capacity: self.incoming_buf_capacity, - pkid_counter: self.pkid_counter, - max_inflight: self.max_inflight, - request_tx: self.request_tx.clone(), - publish_topic: publish_topic.into(), - publish_qos, - }; - let subscriber = Subscriber { - outgoing_buf: self.incoming_buf, - incoming_buf: self.outgoing_buf, - incoming_buf_cache: self.incoming_buf_cache, - request_buf_capacity: self.incoming_buf_capacity, - request_tx: self.request_tx, - }; - Ok((publisher, subscriber)) - } - fn increment_pkid(&self) -> u16 { let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); loop { diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index 3156126df..cb47da36e 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -8,10 +8,6 @@ use tokio::runtime::{self, Runtime}; mod asyncclient; pub use asyncclient::AsyncClient; -mod publisher; -pub use publisher::Publisher; -mod subscriber; -pub use subscriber::Subscriber; mod syncclient; pub use syncclient::Client; diff --git a/rumqttc/src/v5/client/publisher.rs b/rumqttc/src/v5/client/publisher.rs deleted file mode 100644 index 3b5c07f2d..000000000 --- a/rumqttc/src/v5/client/publisher.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, atomic::{AtomicU16, Ordering}, Mutex}, -}; - -use bytes::Bytes; -use flume::{SendError, Sender, TrySendError}; - -use crate::v5::{packet::Publish, ClientError, QoS, Request}; - -pub struct Publisher { - pub(crate) incoming_buf: Arc>>, - pub(crate) incoming_buf_capacity: usize, - pub(crate) pkid_counter: Arc, - pub(crate) max_inflight: u16, - pub(crate) request_tx: Sender<()>, - pub(crate) publish_topic: String, - pub(crate) publish_qos: QoS, -} - -impl Publisher { - /// Sends a MQTT Publish to the eventloop - pub async fn publish( - &self, - retain: bool, - payload: impl Into>, - ) -> Result { - let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); - publish.retain = retain; - let pkid = self.increment_pkid(); - publish.pkid = pkid; - self.send_async_and_notify(Request::Publish(publish)) - .await?; - Ok(pkid) - } - - /// Sends a MQTT Publish to the eventloop - pub fn try_publish( - &self, - retain: bool, - payload: impl Into>, - ) -> Result { - let mut publish = Publish::new(&self.publish_topic, self.publish_qos, payload); - publish.retain = retain; - let pkid = self.increment_pkid(); - publish.pkid = pkid; - self.try_send_and_notify(Request::Publish(publish))?; - Ok(pkid) - } - - /// Sends a MQTT Publish to the eventloop - pub async fn publish_bytes(&self, retain: bool, payload: Bytes) -> Result { - let mut publish = Publish::from_bytes(&self.publish_topic, self.publish_qos, payload); - let pkid = self.increment_pkid(); - publish.pkid = pkid; - publish.retain = retain; - self.send_async_and_notify(Request::Publish(publish)).await?; - Ok(pkid) - } - - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { - { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - } - if let Err(SendError(_)) = self.request_tx.send_async(()).await { - return Err(ClientError::EventloopClosed); - }; - Ok(()) - } - - /// Sends a MQTT disconnect to the eventloop - pub async fn disconnect(&self) -> Result<(), ClientError> { - self.send_async_and_notify(Request::Disconnect).await - } - - /// Sends a MQTT disconnect to the eventloop - pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.try_send_and_notify(Request::Disconnect) - } - - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { - return Err(ClientError::EventloopClosed); - } - Ok(()) - } - - fn increment_pkid(&self) -> u16 { - let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); - loop { - let new_pkid = if cur_pkid > self.max_inflight { - 1 - } else { - cur_pkid + 1 - }; - match self.pkid_counter.compare_exchange( - cur_pkid, - new_pkid, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_prev_pkid) => break new_pkid, - Err(actual_pkid) => cur_pkid = actual_pkid, - } - } - } -} diff --git a/rumqttc/src/v5/client/subscriber.rs b/rumqttc/src/v5/client/subscriber.rs deleted file mode 100644 index d1982bddd..000000000 --- a/rumqttc/src/v5/client/subscriber.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; - -use flume::{SendError, Sender, TrySendError}; - -use crate::v5::{ - client::get_ack_req, ClientError, Publish, QoS, Request, Subscribe, SubscribeFilter, - Unsubscribe, -}; - -#[derive(Debug, Clone)] -pub struct Subscriber { - pub(crate) outgoing_buf: Arc>>, - pub(crate) incoming_buf: Arc>>, - pub(crate) incoming_buf_cache: VecDeque, - pub(crate) request_buf_capacity: usize, - pub(crate) request_tx: Sender<()>, -} - -impl Subscriber { - pub fn next_publish(&mut self) -> Option { - if let Some(publish) = self.incoming_buf_cache.pop_front() { - return Some(publish); - } - - std::mem::swap( - &mut self.incoming_buf_cache, - &mut *self.incoming_buf.lock().unwrap(), - ); - self.incoming_buf_cache.pop_front() - } - - /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { - if let Some(ack) = get_ack_req(qos, pkid) { - self.send_async_and_notify(ack).await?; - } - Ok(()) - } - - /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, qos: QoS, pkid: u16) -> Result<(), ClientError> { - if let Some(ack) = get_ack_req(qos, pkid) { - self.try_send_and_notify(ack)?; - } - Ok(()) - } - - /// Sends a MQTT Subscribe to the eventloop - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.send_async_and_notify(Request::Subscribe(subscribe)) - .await - } - - /// Sends a MQTT Subscribe to the eventloop - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.try_send_and_notify(Request::Subscribe(subscribe)) - } - - /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> - where - T: IntoIterator, - { - let subscribe = Subscribe::new_many(topics); - self.send_async_and_notify(Request::Subscribe(subscribe)) - .await - } - - /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> - where - T: IntoIterator, - { - let subscribe = Subscribe::new_many(topics); - self.try_send_and_notify(Request::Subscribe(subscribe)) - } - - /// Sends a MQTT Unsubscribe to the eventloop - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); - self.send_async_and_notify(Request::Unsubscribe(unsubscribe)) - .await - } - - /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); - self.try_send_and_notify(Request::Unsubscribe(unsubscribe)) - } - - async fn send_async_and_notify(&self, request: Request) -> Result<(), ClientError> { - { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - } - if let Err(SendError(_)) = self.request_tx.send_async(()).await { - return Err(ClientError::EventloopClosed); - }; - Ok(()) - } - - fn try_send_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.request_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { - return Err(ClientError::EventloopClosed); - } - Ok(()) - } -} From 6fbf234070069529793c30048193438705edd753 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:50:31 +0530 Subject: [PATCH 31/38] rumqttc: eventloop: remove `Eventloop.event`, `Event` and `Outgoing` Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 22 +++++---- rumqttc/examples/async_manual_acks_v5.rs | 29 ++++++------ rumqttc/src/v5/client/asyncclient.rs | 11 +++-- rumqttc/src/v5/client/mod.rs | 6 +-- rumqttc/src/v5/eventloop.rs | 43 ++++++----------- rumqttc/src/v5/mod.rs | 29 +----------- rumqttc/src/v5/state.rs | 59 ++---------------------- 7 files changed, 58 insertions(+), 141 deletions(-) diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index c89112332..000a060b7 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{AsyncClient, EventLoop, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; @@ -44,21 +44,23 @@ async fn main() -> Result<(), Box> { } // create new broker connection - let (client, mut eventloop) = create_conn(); + let (_client, mut eventloop) = create_conn(); loop { // previously published messages should be republished after reconnection. let event = eventloop.poll().await; println!("{:?}", event); - if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { - // this time we will ack incoming publishes. - // Its important not to block eventloop as this can cause deadlock. - let c = client.clone(); - tokio::spawn(async move { - c.ack(&publish).await.unwrap(); - }); - } + todo!("fix the commented out code below") + + // if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { + // // this time we will ack incoming publishes. + // // Its important not to block eventloop as this can cause deadlock. + // let c = client.clone(); + // tokio::spawn(async move { + // c.ack(&publish).await.unwrap(); + // }); + // } } } diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs index eb01d2e3d..471df0f40 100644 --- a/rumqttc/examples/async_manual_acks_v5.rs +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v5::{AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; +use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; @@ -44,23 +44,26 @@ async fn main() -> Result<(), Box> { } // create new broker connection - let (client, mut eventloop) = create_conn(); + let (_client, mut eventloop) = create_conn(); loop { // previously published messages should be republished after reconnection. let event = eventloop.poll().await; println!("{:?}", event); - match event { - Ok(Event::Incoming(Incoming::Publish(publish))) => { - // this time we will ack incoming publishes. - // Its important not to block eventloop as this can cause deadlock. - let c = client.clone(); - tokio::spawn(async move { - c.ack(&publish).await.unwrap(); - }); - } - _ => {} - } + + todo!("fix the commented out code below") + + // match event { + // Ok(Event::Incoming(Incoming::Publish(publish))) => { + // // this time we will ack incoming publishes. + // // Its important not to block eventloop as this can cause deadlock. + // let c = client.clone(); + // tokio::spawn(async move { + // c.ack(&publish).await.unwrap(); + // }); + // } + // _ => {} + // } } } diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 5483da737..da778daee 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -12,7 +12,7 @@ use flume::{SendError, Sender, TrySendError}; use crate::v5::{ client::get_ack_req, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, - ClientError, EventLoop, MqttOptions, QoS, Request, + ClientError, EventLoop, Incoming, MqttOptions, QoS, Request, }; /// `AsyncClient` to communicate with MQTT `Eventloop` @@ -20,10 +20,10 @@ use crate::v5::{ #[derive(Clone, Debug)] pub struct AsyncClient { incoming_buf: Arc>>, - outgoing_buf: Arc>>, + outgoing_buf: Arc>>, pkid_counter: Arc, max_inflight: u16, - incoming_buf_cache: VecDeque, + incoming_buf_cache: VecDeque, incoming_buf_capacity: usize, request_tx: Sender<()>, } @@ -124,7 +124,8 @@ impl AsyncClient { publish.retain = retain; let pkid = self.increment_pkid(); publish.pkid = pkid; - self.push_and_async_notify(Request::Publish(publish)).await?; + self.push_and_async_notify(Request::Publish(publish)) + .await?; Ok(pkid) } @@ -221,7 +222,7 @@ impl AsyncClient { Ok(()) } - pub fn next_publish(&mut self) -> Option { + pub fn next_publish(&mut self) -> Option { if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index cb47da36e..811b6889d 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -1,6 +1,6 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::v5::{packet::*, ConnectionError, Event, EventLoop, Request}; +use crate::v5::{packet::*, ConnectionError, EventLoop, Request}; use flume::SendError; use std::mem; @@ -69,12 +69,12 @@ pub struct Iter<'a> { } impl<'a> Iterator for Iter<'a> { - type Item = Result; + type Item = Result<(), ConnectionError>; fn next(&mut self) -> Option { let f = self.connection.eventloop.poll(); match self.runtime.block_on(f) { - Ok(v) => Some(Ok(v)), + Ok(_) => Some(Ok(())), // closing of request channel should stop the iterator Err(ConnectionError::RequestsDone) => { trace!("Done with requests"); diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index a28fd1e23..74edd7a6c 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,5 +1,5 @@ use crate::v5::{ - framed::Network, packet::*, Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, + framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, StateError, Transport, }; #[cfg(feature = "use-rustls")] @@ -69,13 +69,6 @@ pub struct EventLoop { pub(crate) keepalive_timeout: Option>>, } -/// Events which can be yielded by the event loop -#[derive(Debug, PartialEq, Clone)] -pub enum Event { - Incoming(Incoming), - Outgoing(Outgoing), -} - impl EventLoop { /// New MQTT `EventLoop` /// @@ -111,7 +104,7 @@ impl EventLoop { &self.incoming_buf } - pub fn sub_events_buf(&self) -> &Arc>> { + pub fn sub_events_buf(&self) -> &Arc>> { &self.state.incoming_buf } @@ -126,29 +119,28 @@ impl EventLoop { /// the broker. Continuing to poll will reconnect to the broker if there is /// a disconnection. /// **NOTE** Don't block this while iterating - pub async fn poll(&mut self) -> Result { + pub async fn poll(&mut self) -> Result<(), ConnectionError> { if self.network.is_none() { - let (network, connack) = connect(&self.options).await?; + let (network, _connack) = connect(&self.options).await?; self.network = Some(network); if self.keepalive_timeout.is_none() { self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); } - return Ok(Event::Incoming(connack)); + return Ok(()); } - match self.select().await { - Ok(v) => Ok(v), - Err(e) => { - self.clean(); - Err(e) - } + if let Err(e) = self.select().await { + self.clean(); + return Err(e); } + + Ok(()) } /// Select on network and requests and generate keepalive pings when necessary - async fn select(&mut self) -> Result { + async fn select(&mut self) -> Result<(), ConnectionError> { let network = self.network.as_mut().unwrap(); // let await_acks = self.state.await_acks; let inflight_full = self.state.inflight >= self.options.inflight; @@ -156,11 +148,6 @@ impl EventLoop { let pending = self.pending.len() > 0; let collision = self.state.collision.is_some(); - // Read buffered events from previous polls before calling a new poll - if let Some(event) = self.state.events.pop_front() { - return Ok(event); - } - // this loop is necessary as self.request_buf might be empty, in which case it is possible // for self.state.events to be empty, and so popping off from it might return None. If None // is returned, we select again. @@ -171,7 +158,7 @@ impl EventLoop { o?; // flush all the acks and return first incoming packet network.flush(&mut self.state.write).await?; - return Ok(self.state.events.pop_front().unwrap()); + return Ok(()); }, // Pull next request from user requests channel. // If conditions in the below branch are for flow control. We read next user @@ -210,7 +197,7 @@ impl EventLoop { network.flush(&mut self.state.write).await?; // remaining events in the self.state.events will be taken out in next call // to poll() even before the select! is used. - return Ok(self.state.events.pop_front().unwrap()) + return Ok(()) } Err(_) => return Err(ConnectionError::RequestsDone), }, @@ -219,7 +206,7 @@ impl EventLoop { Some(request) = next_pending(throttle, &mut self.pending), if pending => { self.state.handle_outgoing_packet(request)?; network.flush(&mut self.state.write).await?; - return Ok(self.state.events.pop_front().unwrap()) + return Ok(()) }, // We generate pings irrespective of network activity. This keeps the ping logic // simple. We can change this behavior in future if necessary (to prevent extra pings) @@ -229,7 +216,7 @@ impl EventLoop { self.state.handle_outgoing_packet(Request::PingReq)?; network.flush(&mut self.state.write).await?; - return Ok(self.state.events.pop_front().unwrap()) + return Ok(()) } } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 943a14f9b..2f26f4cc2 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -12,7 +12,7 @@ mod state; mod tls; pub use client::{AsyncClient, Client, ClientError, Connection}; -pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use eventloop::{ConnectionError, EventLoop}; pub use flume::{SendError, Sender, TrySendError}; pub use packet::*; pub use state::{MqttState, StateError}; @@ -23,33 +23,6 @@ pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; -/// Current outgoing activity on the eventloop -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum Outgoing { - /// Publish packet with packet identifier. 0 implies QoS 0 - Publish(u16), - /// Subscribe packet with packet identifier - Subscribe(u16), - /// Unsubscribe packet with packet identifier - Unsubscribe(u16), - /// PubAck packet - PubAck(u16), - /// PubRec packet - PubRec(u16), - /// PubRel packet - PubRel(u16), - /// PubComp packet - PubComp(u16), - /// Ping request packet - PingReq, - /// Ping response packet - PingResp, - /// Disconnect packet - Disconnect, - /// Await for an ack for more outgoing progress - AwaitAck(u16), -} - /// Requests by the client to mqtt event loop. Request are /// handled one by one. #[derive(Clone, Debug, PartialEq)] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8ad0e5375..5af3c1b8d 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,4 +1,4 @@ -use super::{packet::*, Event, Incoming, Outgoing, Request}; +use super::{packet::*, Incoming, Request}; use bytes::BytesMut; use std::collections::VecDeque; @@ -72,13 +72,11 @@ pub struct MqttState { pub(crate) incoming_pub: Vec>, /// Last collision due to broker not acking in order pub collision: Option, - /// Buffered incoming packets - pub events: VecDeque, /// Write buffer pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - pub(crate) incoming_buf: Arc>>, + pub(crate) incoming_buf: Arc>>, pkid_counter: Arc, } @@ -100,7 +98,6 @@ impl MqttState { incoming_pub: vec![None; std::u16::MAX as usize + 1], collision: None, // TODO: Optimize these sizes later - events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), @@ -213,7 +210,7 @@ impl MqttState { }; out?; - self.events.push_back(Event::Incoming(packet)); + self.incoming_buf.lock().unwrap().push_back(packet); self.last_incoming = Instant::now(); Ok(()) } @@ -249,12 +246,6 @@ impl MqttState { } } - // TODO: maybe limit the capacity of `self.incoming_buf` - self.incoming_buf - .lock() - .unwrap() - .push_back(publish.clone()); - Ok(()) } @@ -275,8 +266,6 @@ impl MqttState { self.inflight += 1; publish.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); self.collision_ping_count = 0; } @@ -289,9 +278,6 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); PubRel::new(pubrec.pkid).write(&mut self.write)?; - - let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); - self.events.push_back(event); Ok(()) } None => { @@ -305,8 +291,6 @@ impl MqttState { match mem::replace(&mut self.incoming_pub[pubrel.pkid as usize], None) { Some(_) => { PubComp::new(pubrel.pkid).write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); - self.events.push_back(event); Ok(()) } None => { @@ -319,8 +303,6 @@ impl MqttState { fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { publish.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); self.collision_ping_count = 0; } @@ -358,8 +340,6 @@ impl MqttState { { info!("Collision on packet id = {:?}", publish.pkid); self.collision = Some(publish); - let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); - self.events.push_back(event); return Ok(()); } @@ -377,8 +357,6 @@ impl MqttState { ); publish.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); Ok(()) } @@ -387,23 +365,16 @@ impl MqttState { debug!("Pubrel. Pkid = {}", pubrel.pkid); PubRel::new(pubrel.pkid).write(&mut self.write)?; - - let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); - self.events.push_back(event); Ok(()) } fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { puback.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); - self.events.push_back(event); Ok(()) } fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { pubrec.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); - self.events.push_back(event); Ok(()) } @@ -437,8 +408,6 @@ impl MqttState { ); PingReq.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PingReq); - self.events.push_back(event); Ok(()) } @@ -452,8 +421,6 @@ impl MqttState { ); subscription.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); - self.events.push_back(event); Ok(()) } @@ -467,8 +434,6 @@ impl MqttState { ); unsub.write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); - self.events.push_back(event); Ok(()) } @@ -476,8 +441,6 @@ impl MqttState { debug!("Disconnect"); Disconnect::new().write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::Disconnect); - self.events.push_back(event); Ok(()) } @@ -528,7 +491,7 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::v5::{packet::*, Event, Incoming, MqttOptions, Outgoing, Request}; + use crate::v5::{packet::*, Incoming, MqttOptions, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); @@ -640,18 +603,6 @@ mod test { mqtt.handle_incoming_publish(&publish1).unwrap(); mqtt.handle_incoming_publish(&publish2).unwrap(); mqtt.handle_incoming_publish(&publish3).unwrap(); - - if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { - assert_eq!(pkid, 2); - } else { - panic!("missing puback") - } - - if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] { - assert_eq!(pkid, 3); - } else { - panic!("missing PubRec") - } } #[test] @@ -671,7 +622,7 @@ mod test { let pkid = mqtt.incoming_pub[3].unwrap(); assert_eq!(pkid, 3); - assert!(mqtt.events.is_empty()); + assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); } #[test] From b1458fdcaa8107041df7ff2281a6abef2192087c Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 21:53:07 +0530 Subject: [PATCH 32/38] rumqttc: client: rename `AsyncClient::next_publish()` -> `AsyncClient::next()` Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index da778daee..d048d9462 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -222,7 +222,7 @@ impl AsyncClient { Ok(()) } - pub fn next_publish(&mut self) -> Option { + pub fn next(&mut self) -> Option { if let Some(publish) = self.incoming_buf_cache.pop_front() { return Some(publish); } From abdd8f5da599382c7139f00b1ad101000579286b Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 24 Mar 2022 22:01:52 +0530 Subject: [PATCH 33/38] rumqttc: state: some inlines Signed-off-by: Abhik Jain --- rumqttc/src/v5/state.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 5af3c1b8d..cf9d2a16c 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -166,6 +166,7 @@ impl MqttState { pending } + #[inline] pub fn inflight(&self) -> u16 { self.inflight } @@ -215,10 +216,12 @@ impl MqttState { Ok(()) } + #[inline] fn handle_incoming_suback(&mut self) -> Result<(), StateError> { Ok(()) } + #[inline] fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { Ok(()) } @@ -318,6 +321,7 @@ impl MqttState { } } + #[inline] fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { self.await_pingresp = false; Ok(()) @@ -360,6 +364,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { let pubrel = self.save_pubrel(pubrel)?; @@ -368,11 +373,13 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { puback.write(&mut self.write)?; Ok(()) } + #[inline] fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { pubrec.write(&mut self.write)?; Ok(()) @@ -411,6 +418,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { let pkid = self.increment_pkid(); subscription.pkid = pkid; @@ -424,6 +432,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { let pkid = self.increment_pkid(); unsub.pkid = pkid; @@ -437,6 +446,7 @@ impl MqttState { Ok(()) } + #[inline] fn outgoing_disconnect(&mut self) -> Result<(), StateError> { debug!("Disconnect"); @@ -444,6 +454,7 @@ impl MqttState { Ok(()) } + #[inline] fn check_collision(&mut self, pkid: u16) -> Option { if let Some(publish) = &self.collision { if publish.pkid == pkid { @@ -454,6 +465,7 @@ impl MqttState { None } + #[inline] fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets From c127d3c1a3febd641574398b55030e2e41ee11d9 Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 25 Mar 2022 11:33:54 +0530 Subject: [PATCH 34/38] rumqttc: ci: ignore packet parsing in cargo clippy Signed-off-by: Abhik Jain --- rumqttc/src/lib.rs | 1 + rumqttc/src/v5/client/asyncclient.rs | 66 +++++++++++++--------------- rumqttc/src/v5/mod.rs | 42 ++++++++++++++++-- rumqttc/src/v5/notifier.rs | 56 +++++++++++++++++++++++ 4 files changed, 126 insertions(+), 39 deletions(-) create mode 100644 rumqttc/src/v5/notifier.rs diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index a579c1be7..e74d6adbf 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -99,6 +99,7 @@ #[macro_use] extern crate log; +#[allow(clippy::all)] pub mod mqttbytes; pub mod v4; pub mod v5; diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index d048d9462..c824f98f5 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -12,40 +12,34 @@ use flume::{SendError, Sender, TrySendError}; use crate::v5::{ client::get_ack_req, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, - ClientError, EventLoop, Incoming, MqttOptions, QoS, Request, + ClientError, EventLoop, MqttOptions, QoS, Request, }; /// `AsyncClient` to communicate with MQTT `Eventloop` /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Clone, Debug)] pub struct AsyncClient { - incoming_buf: Arc>>, - outgoing_buf: Arc>>, - pkid_counter: Arc, - max_inflight: u16, - incoming_buf_cache: VecDeque, - incoming_buf_capacity: usize, - request_tx: Sender<()>, + pub(crate) outgoing_buf: Arc>>, + pub(crate) outgoing_buf_capacity: usize, + pub(crate) pkid_counter: Arc, + pub(crate) max_inflight: u16, + pub(crate) request_tx: Sender<()>, } impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); - let request_buf = eventloop.request_buf().clone(); - let incoming_buf = eventloop.state.incoming_buf.clone(); - let incoming_buf_cache = VecDeque::with_capacity(cap); + let outgoing_buf = eventloop.request_buf().clone(); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { - incoming_buf: request_buf, - incoming_buf_capacity: cap, - outgoing_buf: incoming_buf, + outgoing_buf, + outgoing_buf_capacity: cap, pkid_counter, max_inflight, - incoming_buf_cache, request_tx, }; @@ -66,7 +60,11 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = self.increment_pkid(); + let pkid = if qos != QoS::AtMostOnce { + self.increment_pkid() + } else { + 0 + }; publish.pkid = pkid; self.push_and_async_notify(Request::Publish(publish)) .await?; @@ -87,7 +85,11 @@ impl AsyncClient { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = self.increment_pkid(); + let pkid = if qos != QoS::AtMostOnce { + self.increment_pkid() + } else { + 0 + }; publish.pkid = pkid; self.push_and_try_notify(Request::Publish(publish))?; Ok(pkid) @@ -122,7 +124,11 @@ impl AsyncClient { { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let pkid = self.increment_pkid(); + let pkid = if qos != QoS::AtMostOnce { + self.increment_pkid() + } else { + 0 + }; publish.pkid = pkid; self.push_and_async_notify(Request::Publish(publish)) .await?; @@ -186,8 +192,8 @@ impl AsyncClient { async fn push_and_async_notify(&self, request: Request) -> Result<(), ClientError> { { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.len() == self.outgoing_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -199,8 +205,8 @@ impl AsyncClient { } pub(crate) fn push_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.len() == self.outgoing_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -211,8 +217,8 @@ impl AsyncClient { } fn push_and_try_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.incoming_buf.lock().unwrap(); - if request_buf.len() == self.incoming_buf_capacity { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.len() == self.outgoing_buf_capacity { return Err(ClientError::RequestsFull); } request_buf.push_back(request); @@ -222,18 +228,6 @@ impl AsyncClient { Ok(()) } - pub fn next(&mut self) -> Option { - if let Some(publish) = self.incoming_buf_cache.pop_front() { - return Some(publish); - } - - std::mem::swap( - &mut self.incoming_buf_cache, - &mut *self.outgoing_buf.lock().unwrap(), - ); - self.incoming_buf_cache.pop_front() - } - fn increment_pkid(&self) -> u16 { let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); loop { diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 2f26f4cc2..d9a62d333 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,11 +1,16 @@ -use std::fmt::{self, Debug, Formatter}; #[cfg(feature = "use-rustls")] use std::sync::Arc; -use std::time::Duration; +use std::{ + collections::VecDeque, + fmt::{self, Debug, Formatter}, + time::Duration, +}; mod client; mod eventloop; mod framed; +mod notifier; +#[allow(clippy::all)] mod packet; mod state; #[cfg(feature = "use-rustls")] @@ -20,6 +25,7 @@ pub use state::{MqttState, StateError}; pub use tls::Error; #[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; +pub use notifier::Notifier; pub type Incoming = Packet; @@ -69,7 +75,7 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, -#[cfg(feature = "use-rustls")] + #[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, @@ -584,6 +590,36 @@ impl Debug for MqttOptions { } } +pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, Notifier), ()> { + let mut eventloop = EventLoop::new(options, cap); + let outgoing_buf = eventloop.request_buf().clone(); + let incoming_buf = eventloop.state.incoming_buf.clone(); + let incoming_buf_cache = VecDeque::with_capacity(cap); + let request_tx = eventloop.handle(); + let max_inflight = eventloop.state.max_inflight; + let pkid_counter = eventloop.state.pkid_counter().clone(); + + let client = AsyncClient { + outgoing_buf, + incoming_buf_capacity: cap, + incoming_buf, + pkid_counter, + max_inflight, + incoming_buf_cache, + request_tx, + }; + + tokio::spawn(async move { + loop { + // TODO: maybe do something like retries for some specific errors? or maybe give user + // options to configure these retries? + eventloop.poll().await.unwrap(); + } + }); + + Ok((client, Notifier {})) +} + #[cfg(test)] mod test { use super::*; diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs new file mode 100644 index 000000000..00554fed1 --- /dev/null +++ b/rumqttc/src/v5/notifier.rs @@ -0,0 +1,56 @@ +use std::{ + collections::VecDeque, + mem, + sync::{Arc, Mutex}, +}; + +use crate::v5::Incoming; + +#[derive(Debug)] +pub struct Notifier { + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, +} + +impl Notifier { + #[inline] + pub(crate) fn new( + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, + ) -> Self { + Self { + incoming_buf, + incoming_buf_cache, + } + } + + #[inline] + pub fn next(&mut self) -> Option { + match self.incoming_buf_cache.pop_front() { + None => { + mem::replace( + &mut self.incoming_buf_cache, + *self.incoming_buf.lock().unwrap(), + ); + self.incoming_buf_cache.pop_front() + } + val => val, + } + } + + #[inline] + pub fn iter(&mut self) -> NotifierIter<'_> { + NotifierIter(self) + } +} + +pub struct NotifierIter<'a>(&'a mut Notifier); + +impl<'a> Iterator for NotifierIter<'a> { + type Item = Incoming; + + #[inline] + fn next(&mut self) -> Option { + self.0.next() + } +} From 913e5f7d63fad31234e2b4d777b7f2bee6654cad Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 25 Mar 2022 11:38:45 +0530 Subject: [PATCH 35/38] rumqttc: ci: ignore packet parsing in cargo clippy Signed-off-by: Abhik Jain --- rumqttc/src/v4/tls.rs | 2 +- rumqttc/src/v5/mod.rs | 8 +++----- rumqttc/src/v5/notifier.rs | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/rumqttc/src/v4/tls.rs b/rumqttc/src/v4/tls.rs index ea1809e23..72eca8aa8 100644 --- a/rumqttc/src/v4/tls.rs +++ b/rumqttc/src/v4/tls.rs @@ -96,7 +96,7 @@ pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result return Err(Error::NoValidCertInChain), }; - let certs = certs.into_iter().map(|cert| Certificate(cert)).collect(); + let certs = certs.into_iter().map(Certificate).collect(); config.with_single_cert(certs, PrivateKey(key))? } else { diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d9a62d333..6e6d4428a 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -19,13 +19,13 @@ mod tls; pub use client::{AsyncClient, Client, ClientError, Connection}; pub use eventloop::{ConnectionError, EventLoop}; pub use flume::{SendError, Sender, TrySendError}; +pub use notifier::Notifier; pub use packet::*; pub use state::{MqttState, StateError}; #[cfg(feature = "use-rustls")] pub use tls::Error; #[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; -pub use notifier::Notifier; pub type Incoming = Packet; @@ -601,11 +601,9 @@ pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, N let client = AsyncClient { outgoing_buf, - incoming_buf_capacity: cap, - incoming_buf, + outgoing_buf_capacity: cap, pkid_counter, max_inflight, - incoming_buf_cache, request_tx, }; @@ -617,7 +615,7 @@ pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, N } }); - Ok((client, Notifier {})) + Ok((client, Notifier::new(incoming_buf, incoming_buf_cache))) } #[cfg(test)] diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs index 00554fed1..4aad90d16 100644 --- a/rumqttc/src/v5/notifier.rs +++ b/rumqttc/src/v5/notifier.rs @@ -28,9 +28,9 @@ impl Notifier { pub fn next(&mut self) -> Option { match self.incoming_buf_cache.pop_front() { None => { - mem::replace( + mem::swap( &mut self.incoming_buf_cache, - *self.incoming_buf.lock().unwrap(), + &mut *self.incoming_buf.lock().unwrap(), ); self.incoming_buf_cache.pop_front() } From cefda92daf8414d59a1f409255444d5f6c316a7d Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sun, 27 Mar 2022 09:20:58 +0530 Subject: [PATCH 36/38] rumqttc: handle pkids at client side Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks_v5.rs | 86 ++-- rumqttc/examples/asyncpubsub_v5.rs | 2 +- rumqttc/src/v5/client/asyncclient.rs | 114 +++-- rumqttc/src/v5/client/syncclient.rs | 36 +- rumqttc/src/v5/mod.rs | 3 +- rumqttc/src/v5/state.rs | 620 +++++++++++------------ 6 files changed, 416 insertions(+), 445 deletions(-) diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs index 471df0f40..fde55ad81 100644 --- a/rumqttc/examples/async_manual_acks_v5.rs +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -1,3 +1,4 @@ +#![allow(dead_code, unused_imports)] use tokio::{task, time}; use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions, QoS}; @@ -16,58 +17,59 @@ fn create_conn() -> (AsyncClient, EventLoop) { #[tokio::main(worker_threads = 1)] async fn main() -> Result<(), Box> { - pretty_env_logger::init(); + todo!("fix this example with new way of spawning clients") + // pretty_env_logger::init(); - // create mqtt connection with clean_session = false and manual_acks = true - let (client, mut eventloop) = create_conn(); + // // create mqtt connection with clean_session = false and manual_acks = true + // let (client, mut eventloop) = create_conn(); - // subscribe example topic - client - .subscribe("hello/world", QoS::AtLeastOnce) - .await - .unwrap(); + // // subscribe example topic + // client + // .subscribe("hello/world", QoS::AtLeastOnce) + // .await + // .unwrap(); - task::spawn(async move { - // send some messages to example topic and disconnect - requests(client.clone()).await; - client.disconnect().await.unwrap() - }); + // task::spawn(async move { + // // send some messages to example topic and disconnect + // requests(client.clone()).await; + // client.disconnect().await.unwrap() + // }); - loop { - // get subscribed messages without acking - let event = eventloop.poll().await; - println!("{:?}", event); - if let Err(_err) = event { - // break loop on disconnection - break; - } - } + // loop { + // // get subscribed messages without acking + // let event = eventloop.poll().await; + // println!("{:?}", event); + // if let Err(_err) = event { + // // break loop on disconnection + // break; + // } + // } - // create new broker connection - let (_client, mut eventloop) = create_conn(); + // // create new broker connection + // let (_client, mut eventloop) = create_conn(); - loop { - // previously published messages should be republished after reconnection. - let event = eventloop.poll().await; - println!("{:?}", event); + // loop { + // // previously published messages should be republished after reconnection. + // let event = eventloop.poll().await; + // println!("{:?}", event); - todo!("fix the commented out code below") + // todo!("fix the commented out code below") - // match event { - // Ok(Event::Incoming(Incoming::Publish(publish))) => { - // // this time we will ack incoming publishes. - // // Its important not to block eventloop as this can cause deadlock. - // let c = client.clone(); - // tokio::spawn(async move { - // c.ack(&publish).await.unwrap(); - // }); - // } - // _ => {} - // } - } + // // match event { + // // Ok(Event::Incoming(Incoming::Publish(publish))) => { + // // // this time we will ack incoming publishes. + // // // Its important not to block eventloop as this can cause deadlock. + // // let c = client.clone(); + // // tokio::spawn(async move { + // // c.ack(&publish).await.unwrap(); + // // }); + // // } + // // _ => {} + // // } + // } } -async fn requests(client: AsyncClient) { +async fn requests(mut client: AsyncClient) { for i in 1..=10 { client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i]) diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs index b0b7a6698..b398a5e3f 100644 --- a/rumqttc/examples/asyncpubsub_v5.rs +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), Box> { } } -async fn requests(client: AsyncClient) { +async fn requests(mut client: AsyncClient) { client .subscribe("hello/world", QoS::AtMostOnce) .await diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index c824f98f5..798439e79 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -1,9 +1,6 @@ use std::{ collections::VecDeque, - sync::{ - atomic::{AtomicU16, Ordering}, - Arc, Mutex, - }, + sync::{Arc, Mutex}, }; use bytes::Bytes; @@ -17,11 +14,11 @@ use crate::v5::{ /// `AsyncClient` to communicate with MQTT `Eventloop` /// This is cloneable and can be used to asynchronously Publish, Subscribe. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct AsyncClient { pub(crate) outgoing_buf: Arc>>, pub(crate) outgoing_buf_capacity: usize, - pub(crate) pkid_counter: Arc, + pub(crate) pkid_counter: u16, pub(crate) max_inflight: u16, pub(crate) request_tx: Sender<()>, } @@ -33,12 +30,11 @@ impl AsyncClient { let outgoing_buf = eventloop.request_buf().clone(); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; - let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { outgoing_buf, outgoing_buf_capacity: cap, - pkid_counter, + pkid_counter: 0, max_inflight, request_tx, }; @@ -48,7 +44,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub async fn publish( - &self, + &mut self, topic: S, qos: QoS, retain: bool, @@ -73,7 +69,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub fn try_publish( - &self, + &mut self, topic: S, qos: QoS, retain: bool, @@ -96,7 +92,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { self.push_and_async_notify(ack).await?; } @@ -104,7 +100,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { self.push_and_try_notify(ack)?; } @@ -113,7 +109,7 @@ impl AsyncClient { /// Sends a MQTT Publish to the eventloop pub async fn publish_bytes( - &self, + &mut self, topic: S, qos: QoS, retain: bool, @@ -136,57 +132,83 @@ impl AsyncClient { } /// Sends a MQTT Subscribe to the eventloop - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); + pub async fn subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; self.push_and_async_notify(Request::Subscribe(subscribe)) - .await + .await?; + Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.push_and_try_notify(Request::Subscribe(subscribe)) + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; + self.push_and_try_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); + let mut subscribe = Subscribe::new_many(topics); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; self.push_and_async_notify(Request::Subscribe(subscribe)) - .await + .await?; + Ok(pkid) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); - self.push_and_try_notify(Request::Subscribe(subscribe)) + let mut subscribe = Subscribe::new_many(topics); + let pkid = self.increment_pkid(); + subscribe.pkid = pkid; + self.push_and_try_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub async fn unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = self.increment_pkid(); + unsubscribe.pkid = pkid; self.push_and_async_notify(Request::Unsubscribe(unsubscribe)) - .await + .await?; + Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); - self.push_and_try_notify(Request::Unsubscribe(unsubscribe)) + pub fn try_unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = self.increment_pkid(); + unsubscribe.pkid = pkid; + self.push_and_try_notify(Request::Unsubscribe(unsubscribe))?; + Ok(pkid) } /// Sends a MQTT disconnect to the eventloop - pub async fn disconnect(&self) -> Result<(), ClientError> { + pub async fn disconnect(&mut self) -> Result<(), ClientError> { self.push_and_async_notify(Request::Disconnect).await } /// Sends a MQTT disconnect to the eventloop - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { self.push_and_try_notify(Request::Disconnect) } @@ -228,23 +250,13 @@ impl AsyncClient { Ok(()) } - fn increment_pkid(&self) -> u16 { - let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); - loop { - let new_pkid = if cur_pkid > self.max_inflight { - 1 - } else { - cur_pkid + 1 - }; - match self.pkid_counter.compare_exchange( - cur_pkid, - new_pkid, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_prev_pkid) => break new_pkid, - Err(actual_pkid) => cur_pkid = actual_pkid, - } - } + #[inline] + pub(crate) fn increment_pkid(&mut self) -> u16 { + self.pkid_counter = if self.pkid_counter == self.max_inflight { + 1 + } else { + self.pkid_counter + 1 + }; + self.pkid_counter } } diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index e517db9ff..ae94a1119 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -10,7 +10,6 @@ use crate::v5::{ /// /// Client is cloneable and can be used to synchronously Publish, Subscribe. /// Asynchronous channel handle can also be extracted if necessary -#[derive(Clone)] pub struct Client { client: AsyncClient, } @@ -71,14 +70,17 @@ impl Client { } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the eventloop - pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> { - let subscribe = Subscribe::new(topic.into(), qos); - self.client.push_and_notify(Request::Subscribe(subscribe)) + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = self.client.increment_pkid(); + subscribe.pkid = pkid; + self.client.push_and_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } /// Sends a MQTT Subscribe to the eventloop @@ -86,20 +88,23 @@ impl Client { &mut self, topic: S, qos: QoS, - ) -> Result<(), ClientError> { + ) -> Result { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the eventloop - pub fn subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { - let subscribe = Subscribe::new_many(topics); - self.client.push_and_notify(Request::Subscribe(subscribe)) + let mut subscribe = Subscribe::new_many(topics); + let pkid = self.client.increment_pkid(); + subscribe.pkid = pkid; + self.client.push_and_notify(Request::Subscribe(subscribe))?; + Ok(pkid) } - pub fn try_subscribe_many(&mut self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&mut self, topics: T) -> Result where T: IntoIterator, { @@ -107,14 +112,17 @@ impl Client { } /// Sends a MQTT Unsubscribe to the eventloop - pub fn unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { - let unsubscribe = Unsubscribe::new(topic.into()); + pub fn unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = self.client.increment_pkid(); + unsubscribe.pkid = pkid; self.client - .push_and_notify(Request::Unsubscribe(unsubscribe)) + .push_and_notify(Request::Unsubscribe(unsubscribe))?; + Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop - pub fn try_unsubscribe>(&mut self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&mut self, topic: S) -> Result { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 6e6d4428a..615d18148 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -597,12 +597,11 @@ pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, N let incoming_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); let max_inflight = eventloop.state.max_inflight; - let pkid_counter = eventloop.state.pkid_counter().clone(); let client = AsyncClient { outgoing_buf, outgoing_buf_capacity: cap, - pkid_counter, + pkid_counter: 0, max_inflight, request_tx, }; diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index cf9d2a16c..0b304f666 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -2,7 +2,6 @@ use super::{packet::*, Incoming, Request}; use bytes::BytesMut; use std::collections::VecDeque; -use std::sync::atomic::{AtomicU16, Ordering}; use std::{ io, mem, sync::{Arc, Mutex}, @@ -77,7 +76,6 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, pub(crate) incoming_buf: Arc>>, - pkid_counter: Arc, } impl MqttState { @@ -101,38 +99,6 @@ impl MqttState { write: BytesMut::with_capacity(10 * 1024), manual_acks, incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), - pkid_counter: Arc::new(AtomicU16::new(0)), - } - } - - #[inline] - pub(crate) fn pkid_counter(&self) -> &Arc { - &self.pkid_counter - } - - #[cfg(test)] - #[inline] - fn cur_pkid(&self) -> u16 { - self.pkid_counter.load(Ordering::SeqCst) - } - - pub(crate) fn increment_pkid(&self) -> u16 { - let mut cur_pkid = self.pkid_counter.load(Ordering::SeqCst); - loop { - let new_pkid = if cur_pkid > self.max_inflight { - 1 - } else { - cur_pkid + 1 - }; - match self.pkid_counter.compare_exchange( - cur_pkid, - new_pkid, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_prev_pkid) => break new_pkid, - Err(actual_pkid) => cur_pkid = actual_pkid, - } } } @@ -329,12 +295,9 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, publish: Publish) -> Result<(), StateError> { if publish.qos != QoS::AtMostOnce { - if publish.pkid == 0 { - publish.pkid = self.increment_pkid(); - } - + // client should set proper pkid let pkid = publish.pkid; if self .outgoing_pub @@ -419,10 +382,8 @@ impl MqttState { } #[inline] - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { - let pkid = self.increment_pkid(); - subscription.pkid = pkid; - + fn outgoing_subscribe(&mut self, subscription: Subscribe) -> Result<(), StateError> { + // client should set correct pkid debug!( "Subscribe. Topics = {:?}, Pkid = {:?}", subscription.filters, subscription.pkid @@ -433,10 +394,7 @@ impl MqttState { } #[inline] - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { - let pkid = self.increment_pkid(); - unsub.pkid = pkid; - + fn outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), StateError> { debug!( "Unsubscribe. Topics = {:?}, Pkid = {:?}", unsub.filters, unsub.pkid @@ -466,16 +424,8 @@ impl MqttState { } #[inline] - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { - let pubrel = match pubrel.pkid { - // consider PacketIdentifier(0) as uninitialized packets - 0 => { - pubrel.pkid = self.increment_pkid(); - pubrel - } - _ => pubrel, - }; - + fn save_pubrel(&mut self, pubrel: PubRel) -> Result { + // pubrel's pkid should already be set correct self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid); Ok(pubrel) } @@ -502,282 +452,282 @@ impl MqttState { #[cfg(test)] mod test { - use super::{MqttState, StateError}; - use crate::v5::{packet::*, Incoming, MqttOptions, Request}; - - fn build_outgoing_publish(qos: QoS) -> Publish { - let topic = "hello/world".to_owned(); - let payload = vec![1, 2, 3]; - - let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - publish.qos = qos; - publish - } - - fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { - let topic = "hello/world".to_owned(); - let payload = vec![1, 2, 3]; - - let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - publish.pkid = pkid; - publish.qos = qos; - publish - } - - fn build_mqttstate() -> MqttState { - MqttState::new(100, false, 100) - } - - #[test] - fn next_pkid_increments_as_expected() { - let mqtt = build_mqttstate(); - - for i in 1..=100 { - let pkid = mqtt.increment_pkid(); - - // loops between 0-99. % 100 == 0 implies border - let expected = i % 100; - if expected == 0 { - break; - } - - assert_eq!(expected, pkid); - } - } - - #[test] - fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { - let mut mqtt = build_mqttstate(); - - // QoS0 Publish - let publish = build_outgoing_publish(QoS::AtMostOnce); - - // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.cur_pkid(), 0); - assert_eq!(mqtt.inflight, 0); - - // QoS1 Publish - let publish = build_outgoing_publish(QoS::AtLeastOnce); - - // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.cur_pkid(), 1); - assert_eq!(mqtt.inflight, 1); - - // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.cur_pkid(), 2); - assert_eq!(mqtt.inflight, 2); - - // QoS1 Publish - let publish = build_outgoing_publish(QoS::ExactlyOnce); - - // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); - assert_eq!(mqtt.cur_pkid(), 3); - assert_eq!(mqtt.inflight, 3); - - // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); - assert_eq!(mqtt.cur_pkid(), 4); - assert_eq!(mqtt.inflight, 4); - } - - #[test] - fn incoming_publish_should_be_added_to_queue_correctly() { - let mut mqtt = build_mqttstate(); - - // QoS0, 1, 2 Publishes - let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); - - let pkid = mqtt.incoming_pub[3].unwrap(); - - // only qos2 publish should be add to queue - assert_eq!(pkid, 3); - } - - #[test] - fn incoming_publish_should_be_acked() { - let mut mqtt = build_mqttstate(); - - // QoS0, 1, 2 Publishes - let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); - } - - #[test] - fn incoming_publish_should_not_be_acked_with_manual_acks() { - let mut mqtt = build_mqttstate(); - mqtt.manual_acks = true; - - // QoS0, 1, 2 Publishes - let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); - - let pkid = mqtt.incoming_pub[3].unwrap(); - assert_eq!(pkid, 3); - - assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); - } - - #[test] - fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { - let mut mqtt = build_mqttstate(); - let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - _ => panic!("Invalid network request: {:?}", packet), - } - } - - #[test] - fn incoming_puback_should_remove_correct_publish_from_queue() { - let mut mqtt = build_mqttstate(); - - let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); - assert_eq!(mqtt.inflight, 2); - - mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); - assert_eq!(mqtt.inflight, 1); - - mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); - assert_eq!(mqtt.inflight, 0); - - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); - } - - #[test] - fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { - let mut mqtt = build_mqttstate(); - - let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); - - mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); - assert_eq!(mqtt.inflight, 2); - - // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); - assert_eq!(backup.unwrap().pkid, 1); - - // check if the qos2 element's release pkid is 2 - assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); - } - - #[test] - fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { - let mut mqtt = build_mqttstate(); - - let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::Publish(publish) => assert_eq!(publish.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - } - - #[test] - fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { - let mut mqtt = build_mqttstate(); - let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - - mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - match packet { - Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), - packet => panic!("Invalid network request: {:?}", packet), - } - } - - #[test] - fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { - let mut mqtt = build_mqttstate(); - let publish = build_outgoing_publish(QoS::ExactlyOnce); - - mqtt.outgoing_publish(publish).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - - mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); - assert_eq!(mqtt.inflight, 0); - } - - #[test] - fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { - let mut mqtt = build_mqttstate(); - let mut opts = MqttOptions::new("test", "localhost", 1883); - opts.set_keep_alive(std::time::Duration::from_secs(10)); - mqtt.outgoing_ping().unwrap(); - - // network activity other than pingresp - let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) - .unwrap(); - mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) - .unwrap(); - - // should throw error because we didn't get pingresp for previous ping - match mqtt.outgoing_ping() { - Ok(_) => panic!("Should throw pingresp await error"), - Err(StateError::AwaitPingResp) => (), - Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), - } - } - - #[test] - fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { - let mut mqtt = build_mqttstate(); - - let mut opts = MqttOptions::new("test", "localhost", 1883); - opts.set_keep_alive(std::time::Duration::from_secs(10)); - - // should ping - mqtt.outgoing_ping().unwrap(); - mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); - - // should ping - mqtt.outgoing_ping().unwrap(); - } + // use super::{MqttState, StateError}; + // use crate::v5::{packet::*, Incoming, MqttOptions, Request}; + + // fn build_outgoing_publish(qos: QoS) -> Publish { + // let topic = "hello/world".to_owned(); + // let payload = vec![1, 2, 3]; + + // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + // publish.qos = qos; + // publish + // } + + // fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + // let topic = "hello/world".to_owned(); + // let payload = vec![1, 2, 3]; + + // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + // publish.pkid = pkid; + // publish.qos = qos; + // publish + // } + + // fn build_mqttstate() -> MqttState { + // MqttState::new(100, false, 100) + // } + + // #[test] + // fn next_pkid_increments_as_expected() { + // let mqtt = build_mqttstate(); + + // for i in 1..=100 { + // let pkid = mqtt.increment_pkid(); + + // // loops between 0-99. % 100 == 0 implies border + // let expected = i % 100; + // if expected == 0 { + // break; + // } + + // assert_eq!(expected, pkid); + // } + // } + + // #[test] + // fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + // let mut mqtt = build_mqttstate(); + + // // QoS0 Publish + // let publish = build_outgoing_publish(QoS::AtMostOnce); + + // // QoS 0 publish shouldn't be saved in queue + // mqtt.outgoing_publish(publish).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 0); + // assert_eq!(mqtt.inflight, 0); + + // // QoS1 Publish + // let publish = build_outgoing_publish(QoS::AtLeastOnce); + + // // Packet id should be set and publish should be saved in queue + // mqtt.outgoing_publish(publish.clone()).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 1); + // assert_eq!(mqtt.inflight, 1); + + // // Packet id should be incremented and publish should be saved in queue + // mqtt.outgoing_publish(publish).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 2); + // assert_eq!(mqtt.inflight, 2); + + // // QoS1 Publish + // let publish = build_outgoing_publish(QoS::ExactlyOnce); + + // // Packet id should be set and publish should be saved in queue + // mqtt.outgoing_publish(publish.clone()).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 3); + // assert_eq!(mqtt.inflight, 3); + + // // Packet id should be incremented and publish should be saved in queue + // mqtt.outgoing_publish(publish).unwrap(); + // assert_eq!(mqtt.cur_pkid(), 4); + // assert_eq!(mqtt.inflight, 4); + // } + + // #[test] + // fn incoming_publish_should_be_added_to_queue_correctly() { + // let mut mqtt = build_mqttstate(); + + // // QoS0, 1, 2 Publishes + // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + // mqtt.handle_incoming_publish(&publish1).unwrap(); + // mqtt.handle_incoming_publish(&publish2).unwrap(); + // mqtt.handle_incoming_publish(&publish3).unwrap(); + + // let pkid = mqtt.incoming_pub[3].unwrap(); + + // // only qos2 publish should be add to queue + // assert_eq!(pkid, 3); + // } + + // #[test] + // fn incoming_publish_should_be_acked() { + // let mut mqtt = build_mqttstate(); + + // // QoS0, 1, 2 Publishes + // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + // mqtt.handle_incoming_publish(&publish1).unwrap(); + // mqtt.handle_incoming_publish(&publish2).unwrap(); + // mqtt.handle_incoming_publish(&publish3).unwrap(); + // } + + // #[test] + // fn incoming_publish_should_not_be_acked_with_manual_acks() { + // let mut mqtt = build_mqttstate(); + // mqtt.manual_acks = true; + + // // QoS0, 1, 2 Publishes + // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + // mqtt.handle_incoming_publish(&publish1).unwrap(); + // mqtt.handle_incoming_publish(&publish2).unwrap(); + // mqtt.handle_incoming_publish(&publish3).unwrap(); + + // let pkid = mqtt.incoming_pub[3].unwrap(); + // assert_eq!(pkid, 3); + + // assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); + // } + + // #[test] + // fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + // let mut mqtt = build_mqttstate(); + // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + // mqtt.handle_incoming_publish(&publish).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + // _ => panic!("Invalid network request: {:?}", packet), + // } + // } + + // #[test] + // fn incoming_puback_should_remove_correct_publish_from_queue() { + // let mut mqtt = build_mqttstate(); + + // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + // mqtt.outgoing_publish(publish1).unwrap(); + // mqtt.outgoing_publish(publish2).unwrap(); + // assert_eq!(mqtt.inflight, 2); + + // mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + // assert_eq!(mqtt.inflight, 1); + + // mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + // assert_eq!(mqtt.inflight, 0); + + // assert!(mqtt.outgoing_pub[1].is_none()); + // assert!(mqtt.outgoing_pub[2].is_none()); + // } + + // #[test] + // fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + // let mut mqtt = build_mqttstate(); + + // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); + // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); + + // let _publish_out = mqtt.outgoing_publish(publish1); + // let _publish_out = mqtt.outgoing_publish(publish2); + + // mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + // assert_eq!(mqtt.inflight, 2); + + // // check if the remaining element's pkid is 1 + // let backup = mqtt.outgoing_pub[1].clone(); + // assert_eq!(backup.unwrap().pkid, 1); + + // // check if the qos2 element's release pkid is 2 + // assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + // } + + // #[test] + // fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + // let mut mqtt = build_mqttstate(); + + // let publish = build_outgoing_publish(QoS::ExactlyOnce); + // mqtt.outgoing_publish(publish).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + + // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + // } + + // #[test] + // fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + // let mut mqtt = build_mqttstate(); + // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + // mqtt.handle_incoming_publish(&publish).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + + // mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + // match packet { + // Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + // packet => panic!("Invalid network request: {:?}", packet), + // } + // } + + // #[test] + // fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + // let mut mqtt = build_mqttstate(); + // let publish = build_outgoing_publish(QoS::ExactlyOnce); + + // mqtt.outgoing_publish(publish).unwrap(); + // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + // mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + // assert_eq!(mqtt.inflight, 0); + // } + + // #[test] + // fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + // let mut mqtt = build_mqttstate(); + // let mut opts = MqttOptions::new("test", "localhost", 1883); + // opts.set_keep_alive(std::time::Duration::from_secs(10)); + // mqtt.outgoing_ping().unwrap(); + + // // network activity other than pingresp + // let publish = build_outgoing_publish(QoS::AtLeastOnce); + // mqtt.handle_outgoing_packet(Request::Publish(publish)) + // .unwrap(); + // mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + // .unwrap(); + + // // should throw error because we didn't get pingresp for previous ping + // match mqtt.outgoing_ping() { + // Ok(_) => panic!("Should throw pingresp await error"), + // Err(StateError::AwaitPingResp) => (), + // Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + // } + // } + + // #[test] + // fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + // let mut mqtt = build_mqttstate(); + + // let mut opts = MqttOptions::new("test", "localhost", 1883); + // opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // // should ping + // mqtt.outgoing_ping().unwrap(); + // mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // // should ping + // mqtt.outgoing_ping().unwrap(); + // } } From 8fffc5d649bce9501c8d8a53e037190d5f37f73b Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Sun, 27 Mar 2022 20:47:10 +0530 Subject: [PATCH 37/38] rumqttc: restore backwards compatibilty Signed-off-by: Abhik Jain --- rumqttc/examples/async_manual_acks.rs | 22 +- rumqttc/examples/asyncpubsub.rs | 2 +- rumqttc/examples/sync_pubsub_weird.rs | 41 -- rumqttc/examples/syncpubsub.rs | 2 +- rumqttc/examples/tls.rs | 2 +- rumqttc/examples/tls2.rs | 3 +- rumqttc/examples/websocket.rs | 2 +- rumqttc/src/{v4 => }/client.rs | 7 +- rumqttc/src/{v4 => }/eventloop.rs | 7 +- rumqttc/src/{v4 => }/framed.rs | 9 +- rumqttc/src/lib.rs | 725 ++++++++++++++++++++++- rumqttc/src/{v4 => }/state.rs | 6 +- rumqttc/src/{v4 => }/tls.rs | 2 +- rumqttc/src/v4/mod.rs | 817 -------------------------- rumqttc/src/v5/client/mod.rs | 1 - rumqttc/src/v5/eventloop.rs | 8 +- rumqttc/tests/broker.rs | 2 +- rumqttc/tests/reliability.rs | 2 +- 18 files changed, 754 insertions(+), 906 deletions(-) delete mode 100644 rumqttc/examples/sync_pubsub_weird.rs rename rumqttc/src/{v4 => }/client.rs (99%) rename rumqttc/src/{v4 => }/eventloop.rs (98%) rename rumqttc/src/{v4 => }/framed.rs (97%) rename rumqttc/src/{v4 => }/state.rs (99%) rename rumqttc/src/{v4 => }/tls.rs (98%) delete mode 100644 rumqttc/src/v4/mod.rs diff --git a/rumqttc/examples/async_manual_acks.rs b/rumqttc/examples/async_manual_acks.rs index 000a060b7..4e9449841 100644 --- a/rumqttc/examples/async_manual_acks.rs +++ b/rumqttc/examples/async_manual_acks.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{AsyncClient, EventLoop, MqttOptions, QoS}; +use rumqttc::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; @@ -44,23 +44,21 @@ async fn main() -> Result<(), Box> { } // create new broker connection - let (_client, mut eventloop) = create_conn(); + let (client, mut eventloop) = create_conn(); loop { // previously published messages should be republished after reconnection. let event = eventloop.poll().await; println!("{:?}", event); - todo!("fix the commented out code below") - - // if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { - // // this time we will ack incoming publishes. - // // Its important not to block eventloop as this can cause deadlock. - // let c = client.clone(); - // tokio::spawn(async move { - // c.ack(&publish).await.unwrap(); - // }); - // } + if let Ok(Event::Incoming(Incoming::Publish(publish))) = event { + // this time we will ack incoming publishes. + // Its important not to block eventloop as this can cause deadlock. + let c = client.clone(); + tokio::spawn(async move { + c.ack(&publish).await.unwrap(); + }); + } } } diff --git a/rumqttc/examples/asyncpubsub.rs b/rumqttc/examples/asyncpubsub.rs index b4de624e7..4ec2cd983 100644 --- a/rumqttc/examples/asyncpubsub.rs +++ b/rumqttc/examples/asyncpubsub.rs @@ -1,6 +1,6 @@ use tokio::{task, time}; -use rumqttc::v4::{AsyncClient, MqttOptions, QoS}; +use rumqttc::{self, AsyncClient, MqttOptions, QoS}; use std::error::Error; use std::time::Duration; diff --git a/rumqttc/examples/sync_pubsub_weird.rs b/rumqttc/examples/sync_pubsub_weird.rs deleted file mode 100644 index e6c7b9c3d..000000000 --- a/rumqttc/examples/sync_pubsub_weird.rs +++ /dev/null @@ -1,41 +0,0 @@ -use rumqttc::v4::{Client, Event, MqttOptions, Packet, QoS}; -use std::thread; -use std::time::Duration; - -fn main() { - let mut mqttoptions = MqttOptions::new("test1", "localhost", 1883); - mqttoptions.set_keep_alive(Duration::from_secs(5)); - mqttoptions.set_inflight(5); - - let (mut client, mut connection) = Client::new(mqttoptions, 25); - - client.subscribe("test/me", QoS::AtMostOnce).unwrap(); - - thread::spawn(move || publish(client)); - - for notification in connection.iter() { - println!("notification: {:?}", notification); - match notification.unwrap() { - Event::Incoming(inc) => match inc { - Packet::Publish(_) => { - thread::sleep(Duration::from_millis(200)); - } - _ => {} - }, - _ => {} - } - } - - println!("Done with the stream!!"); -} - -fn publish(client: Client) { - loop { - let payload = b"Foobar"; - let mut cl2 = client.clone(); - cl2.publish("test/me", QoS::AtLeastOnce, false, &payload[..]) - .unwrap(); - thread::sleep(Duration::from_millis(50)); - println!("send!"); - } -} diff --git a/rumqttc/examples/syncpubsub.rs b/rumqttc/examples/syncpubsub.rs index 2001b4d29..bb2171cfd 100644 --- a/rumqttc/examples/syncpubsub.rs +++ b/rumqttc/examples/syncpubsub.rs @@ -1,4 +1,4 @@ -use rumqttc::v4::{Client, LastWill, MqttOptions, QoS}; +use rumqttc::{self, Client, LastWill, MqttOptions, QoS}; use std::thread; use std::time::Duration; diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index f765d2f5a..0e25860e3 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -4,7 +4,7 @@ use std::error::Error; #[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { - use rumqttc::v4::{AsyncClient, Event, Incoming, MqttOptions, Transport}; + use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; use rustls::ClientConfig; pretty_env_logger::init(); diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index c2106d585..496a806a0 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,11 +1,10 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. - use std::error::Error; #[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { - use rumqttc::v4::{AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; + use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; pretty_env_logger::init(); color_backtrace::install(); diff --git a/rumqttc/examples/websocket.rs b/rumqttc/examples/websocket.rs index b5d1e39a4..712427b3c 100644 --- a/rumqttc/examples/websocket.rs +++ b/rumqttc/examples/websocket.rs @@ -1,5 +1,5 @@ #[cfg(feature = "websocket")] -use rumqttc::v4::{AsyncClient, MqttOptions, QoS, Transport}; +use rumqttc::{self, AsyncClient, MqttOptions, QoS, Transport}; #[cfg(feature = "websocket")] use std::{error::Error, time::Duration}; #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v4/client.rs b/rumqttc/src/client.rs similarity index 99% rename from rumqttc/src/v4/client.rs rename to rumqttc/src/client.rs index 31a56c6ae..8464df271 100644 --- a/rumqttc/src/v4/client.rs +++ b/rumqttc/src/client.rs @@ -1,10 +1,7 @@ //! This module offers a high level synchronous and asynchronous abstraction to //! async eventloop. -use crate::{ - mqttbytes, - mqttbytes::v4::*, - v4::{ConnectionError, Event, EventLoop, MqttOptions, QoS, Request}, -}; +use crate::mqttbytes::{self, v4::*, QoS}; +use crate::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use async_channel::{SendError, Sender, TrySendError}; use bytes::Bytes; diff --git a/rumqttc/src/v4/eventloop.rs b/rumqttc/src/eventloop.rs similarity index 98% rename from rumqttc/src/v4/eventloop.rs rename to rumqttc/src/eventloop.rs index efb8da1af..e65619b3f 100644 --- a/rumqttc/src/v4/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,8 +1,7 @@ -use crate::v4::{framed::Network, Transport}; -use crate::v4::{Incoming, MqttState, Packet, Request, StateError}; -use crate::v4::{MqttOptions, Outgoing}; +use crate::framed::Network; #[cfg(feature = "use-rustls")] -use crate::v4::tls; +use crate::tls; +use crate::{Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport}; use crate::mqttbytes::v4::*; use async_channel::{bounded, Receiver, Sender}; diff --git a/rumqttc/src/v4/framed.rs b/rumqttc/src/framed.rs similarity index 97% rename from rumqttc/src/v4/framed.rs rename to rumqttc/src/framed.rs index 171f689c7..b0a536e78 100644 --- a/rumqttc/src/v4/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,13 +1,8 @@ use bytes::BytesMut; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::{ - mqttbytes::{ - self, - v4::{read, Connect}, - }, - v4::{Incoming, MqttState, StateError}, -}; +use crate::mqttbytes::{self, v4::*}; +use crate::{Incoming, MqttState, StateError}; use std::io; /// Network transforms packets <-> frames efficiently. It takes diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index e74d6adbf..a124168d3 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -9,7 +9,7 @@ //! ---------------------------- //! //! ```no_run -//! use rumqttc::v4::{MqttOptions, Client, QoS}; +//! use rumqttc::{MqttOptions, Client, QoS}; //! use std::time::Duration; //! use std::thread; //! @@ -33,7 +33,7 @@ //! ------------------------------ //! //! ```no_run -//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; +//! use rumqttc::{MqttOptions, AsyncClient, QoS}; //! use tokio::{task, time}; //! use std::time::Duration; //! use std::error::Error; @@ -99,7 +99,724 @@ #[macro_use] extern crate log; -#[allow(clippy::all)] +use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] +use std::sync::Arc; +use std::time::Duration; + +mod client; +mod eventloop; +mod framed; pub mod mqttbytes; -pub mod v4; +mod state; +#[cfg(feature = "use-rustls")] +mod tls; pub mod v5; + +pub use async_channel::{SendError, Sender, TrySendError}; +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, Event, EventLoop}; +pub use mqttbytes::v4::*; +pub use mqttbytes::*; +pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] +pub use tls::Error as TlsError; +#[cfg(feature = "use-rustls")] +pub use tokio_rustls::rustls::ClientConfig; + +pub type Incoming = Packet; + +/// Current outgoing activity on the eventloop +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// Subscribe packet with packet identifier + Subscribe(u16), + /// Unsubscribe packet with packet identifier + Unsubscribe(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, + /// Await for an ack for more outgoing progress + AwaitAck(u16), +} + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + #[cfg(feature = "use-rustls")] + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + #[cfg(feature = "use-rustls")] + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +#[cfg(feature = "use-rustls")] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +#[cfg(feature = "use-rustls")] +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false, + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/v4/state.rs b/rumqttc/src/state.rs similarity index 99% rename from rumqttc/src/v4/state.rs rename to rumqttc/src/state.rs index 7f6d5959f..399001cd8 100644 --- a/rumqttc/src/v4/state.rs +++ b/rumqttc/src/state.rs @@ -1,4 +1,4 @@ -use crate::v4::{Event, Incoming, Outgoing, Request}; +use crate::{Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; @@ -487,7 +487,9 @@ impl MqttState { #[cfg(test)] mod test { use super::{MqttState, StateError}; - use crate::{mqttbytes::v4::read, v4::*}; + use crate::mqttbytes::v4::*; + use crate::mqttbytes::*; + use crate::{Event, Incoming, MqttOptions, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); diff --git a/rumqttc/src/v4/tls.rs b/rumqttc/src/tls.rs similarity index 98% rename from rumqttc/src/v4/tls.rs rename to rumqttc/src/tls.rs index 72eca8aa8..c740626b3 100644 --- a/rumqttc/src/v4/tls.rs +++ b/rumqttc/src/tls.rs @@ -7,7 +7,7 @@ use tokio_rustls::rustls::{ use tokio_rustls::webpki; use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::v4::{Key, MqttOptions, TlsConfiguration}; +use crate::{Key, MqttOptions, TlsConfiguration}; use std::convert::TryFrom; use std::io; diff --git a/rumqttc/src/v4/mod.rs b/rumqttc/src/v4/mod.rs deleted file mode 100644 index 0c39b6478..000000000 --- a/rumqttc/src/v4/mod.rs +++ /dev/null @@ -1,817 +0,0 @@ -//! A pure rust MQTT client which strives to be robust, efficient and easy to use. -//! This library is backed by an async (tokio) eventloop which handles all the -//! robustness and and efficiency parts of MQTT but naturally fits into both sync -//! and async worlds as we'll see -//! -//! Let's jump into examples right away -//! -//! A simple synchronous publish and subscribe -//! ---------------------------- -//! -//! ```no_run -//! use rumqttc::v4::{MqttOptions, Client, QoS}; -//! use std::time::Duration; -//! use std::thread; -//! -//! let mut mqttoptions = MqttOptions::new("rumqtt-sync", "test.mosquitto.org", 1883); -//! mqttoptions.set_keep_alive(Duration::from_secs(5)); -//! -//! let (mut client, mut connection) = Client::new(mqttoptions, 10); -//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).unwrap(); -//! thread::spawn(move || for i in 0..10 { -//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).unwrap(); -//! thread::sleep(Duration::from_millis(100)); -//! }); -//! -//! // Iterate to poll the eventloop for connection progress -//! for (i, notification) in connection.iter().enumerate() { -//! println!("Notification = {:?}", notification); -//! } -//! ``` -//! -//! A simple asynchronous publish and subscribe -//! ------------------------------ -//! -//! ```no_run -//! use rumqttc::v4::{MqttOptions, AsyncClient, QoS}; -//! use tokio::{task, time}; -//! use std::time::Duration; -//! use std::error::Error; -//! -//! # #[tokio::main(worker_threads = 1)] -//! # async fn main() { -//! let mut mqttoptions = MqttOptions::new("rumqtt-async", "test.mosquitto.org", 1883); -//! mqttoptions.set_keep_alive(Duration::from_secs(5)); -//! -//! let (mut client, mut eventloop) = AsyncClient::new(mqttoptions, 10); -//! client.subscribe("hello/rumqtt", QoS::AtMostOnce).await.unwrap(); -//! -//! task::spawn(async move { -//! for i in 0..10 { -//! client.publish("hello/rumqtt", QoS::AtLeastOnce, false, vec![i; i as usize]).await.unwrap(); -//! time::sleep(Duration::from_millis(100)).await; -//! } -//! }); -//! -//! loop { -//! let notification = eventloop.poll().await.unwrap(); -//! println!("Received = {:?}", notification); -//! } -//! # } -//! ``` -//! -//! Quick overview of features -//! - Eventloop orchestrates outgoing/incoming packets concurrently and hadles the state -//! - Pings the broker when necessary and detects client side half open connections as well -//! - Throttling of outgoing packets (todo) -//! - Queue size based flow control on outgoing packets -//! - Automatic reconnections by just continuing the `eventloop.poll()/connection.iter()` loop` -//! - Natural backpressure to client APIs during bad network -//! - Immediate cancellation with `client.cancel()` -//! -//! In short, everything necessary to maintain a robust connection -//! -//! Since the eventloop is externally polled (with `iter()/poll()` in a loop) -//! out side the library and `Eventloop` is accessible, users can -//! - Distribute incoming messages based on topics -//! - Stop it when required -//! - Access internal state for use cases like graceful shutdown or to modify options before reconnection -//! -//! ## Important notes -//! -//! - Looping on `connection.iter()`/`eventloop.poll()` is necessary to run the -//! event loop and make progress. It yields incoming and outgoing activity -//! notifications which allows customization as you see fit. -//! -//! - Blocking inside the `connection.iter()`/`eventloop.poll()` loop will block -//! connection progress. -//! -//! ## FAQ -//! **Connecting to a broker using raw ip doesn't work** -//! -//! You cannot create a TLS connection to a bare IP address with a self-signed -//! certificate. This is a [limitation of rustls](https://github.com/ctz/rustls/issues/184). -//! One workaround, which only works under *nix/BSD-like systems, is to add an -//! entry to wherever your DNS resolver looks (e.g. `/etc/hosts`) for the bare IP -//! address and use that name in your code. -#![cfg_attr(docsrs, feature(doc_cfg))] - -use std::fmt::{self, Debug, Formatter}; -#[cfg(feature = "use-rustls")] -use std::sync::Arc; -use std::time::Duration; - -mod client; -mod eventloop; -mod framed; -mod state; -#[cfg(feature = "use-rustls")] -mod tls; - -pub use async_channel::{SendError, Sender, TrySendError}; -pub use client::{AsyncClient, Client, ClientError, Connection}; -pub use eventloop::{ConnectionError, Event, EventLoop}; -pub use crate::mqttbytes::v4::*; -pub use crate::mqttbytes::*; -pub use state::{MqttState, StateError}; -#[cfg(feature = "use-rustls")] -pub use tls::Error as TlsError; -#[cfg(feature = "use-rustls")] -pub use tokio_rustls::rustls::ClientConfig; - -pub type Incoming = Packet; - -/// Current outgoing activity on the eventloop -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum Outgoing { - /// Publish packet with packet identifier. 0 implies QoS 0 - Publish(u16), - /// Subscribe packet with packet identifier - Subscribe(u16), - /// Unsubscribe packet with packet identifier - Unsubscribe(u16), - /// PubAck packet - PubAck(u16), - /// PubRec packet - PubRec(u16), - /// PubRel packet - PubRel(u16), - /// PubComp packet - PubComp(u16), - /// Ping request packet - PingReq, - /// Ping response packet - PingResp, - /// Disconnect packet - Disconnect, - /// Await for an ack for more outgoing progress - AwaitAck(u16), -} - -/// Requests by the client to mqtt event loop. Request are -/// handled one by one. -#[derive(Clone, Debug, PartialEq)] -pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, -} - -/// Key type for TLS authentication -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum Key { - RSA(Vec), - ECC(Vec), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } -} - -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } -} - -#[derive(Clone)] -pub enum Transport { - Tcp, - #[cfg(feature = "use-rustls")] - Tls(TlsConfiguration), - #[cfg(unix)] - Unix, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - Ws, - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] - Wss(TlsConfiguration), -} - -impl Default for Transport { - fn default() -> Self { - Self::tcp() - } -} - -impl Transport { - /// Use regular tcp as transport (default) - pub fn tcp() -> Self { - Self::Tcp - } - - /// Use secure tcp with tls as transport - #[cfg(feature = "use-rustls")] - pub fn tls( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - alpn, - client_auth, - }; - - Self::tls_with_config(config) - } - - #[cfg(feature = "use-rustls")] - pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { - Self::Tls(tls_config) - } - - #[cfg(unix)] - pub fn unix() -> Self { - Self::Unix - } - - /// Use websockets as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] - pub fn ws() -> Self { - Self::Ws - } - - /// Use secure websockets with tls as transport - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] - pub fn wss( - ca: Vec, - client_auth: Option<(Vec, Key)>, - alpn: Option>>, - ) -> Self { - let config = TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }; - - Self::wss_with_config(config) - } - - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] - pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { - Self::Wss(tls_config) - } -} - -#[derive(Clone)] -#[cfg(feature = "use-rustls")] -pub enum TlsConfiguration { - Simple { - /// connection method - ca: Vec, - /// alpn settings - alpn: Option>>, - /// tls client_authentication - client_auth: Option<(Vec, Key)>, - }, - /// Injected rustls ClientConfig for TLS, to allow more customisation. - Rustls(Arc), -} - -#[cfg(feature = "use-rustls")] -impl From for TlsConfiguration { - fn from(config: ClientConfig) -> Self { - TlsConfiguration::Rustls(Arc::new(config)) - } -} - -// TODO: Should all the options be exposed as public? Drawback -// would be loosing the ability to panic when the user options -// are wrong (e.g empty client id) or aggressive (keep alive time) -/// Options to configure the behaviour of mqtt connection -#[derive(Clone)] -pub struct MqttOptions { - /// broker address that you want to connect to - broker_addr: String, - /// broker port - port: u16, - // What transport protocol to use - transport: Transport, - /// keep alive time to send pingreq to broker when the connection is idle - keep_alive: Duration, - /// clean (or) persistent session - clean_session: bool, - /// client identifier - client_id: String, - /// username and password - credentials: Option<(String, String)>, - /// maximum incoming packet size (verifies remaining length of the packet) - max_incoming_packet_size: usize, - /// Maximum outgoing packet size (only verifies publish payload size) - // TODO Verify this with all packets. This can be packet.write but message left in - // the state might be a footgun as user has to explicitly clean it. Probably state - // has to be moved to network - max_outgoing_packet_size: usize, - /// request (publish, subscribe) channel capacity - request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, - /// Minimum delay time between consecutive outgoing packets - /// while retransmitting pending packets - pending_throttle: Duration, - /// maximum number of outgoing inflight messages - inflight: u16, - /// Last will that will be issued on unexpected disconnect - last_will: Option, - /// Connection timeout - conn_timeout: u64, - /// If set to `true` MQTT acknowledgements are not sent automatically. - /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. - manual_acks: bool, -} - -impl MqttOptions { - /// New mqtt options - pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { - let id = id.into(); - if id.starts_with(' ') || id.is_empty() { - panic!("Invalid client id") - } - - MqttOptions { - broker_addr: host.into(), - port, - transport: Transport::tcp(), - keep_alive: Duration::from_secs(60), - clean_session: true, - client_id: id, - credentials: None, - max_incoming_packet_size: 10 * 1024, - max_outgoing_packet_size: 10 * 1024, - request_channel_capacity: 10, - max_request_batch: 0, - pending_throttle: Duration::from_micros(0), - inflight: 100, - last_will: None, - conn_timeout: 5, - manual_acks: false, - } - } - - /// Broker address - pub fn broker_address(&self) -> (String, u16) { - (self.broker_addr.clone(), self.port) - } - - pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { - self.last_will = Some(will); - self - } - - pub fn last_will(&self) -> Option { - self.last_will.clone() - } - - pub fn set_transport(&mut self, transport: Transport) -> &mut Self { - self.transport = transport; - self - } - - pub fn transport(&self) -> Transport { - self.transport.clone() - } - - /// Set number of seconds after which client should ping the broker - /// if there is no other data exchange - pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } - - self.keep_alive = duration; - self - } - - /// Keep alive time - pub fn keep_alive(&self) -> Duration { - self.keep_alive - } - - /// Client identifier - pub fn client_id(&self) -> String { - self.client_id.clone() - } - - /// Set packet size limit for outgoing an incoming packets - pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { - self.max_incoming_packet_size = incoming; - self.max_outgoing_packet_size = outgoing; - self - } - - /// Maximum packet size - pub fn max_packet_size(&self) -> usize { - self.max_incoming_packet_size - } - - /// `clean_session = true` removes all the state from queues & instructs the broker - /// to clean all the client state when client disconnects. - /// - /// When set `false`, broker will hold the client state and performs pending - /// operations on the client when reconnection with same `client_id` - /// happens. Local queue state is also held to retransmit packets after reconnection. - pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { - self.clean_session = clean_session; - self - } - - /// Clean session - pub fn clean_session(&self) -> bool { - self.clean_session - } - - /// Username and password - pub fn set_credentials, P: Into>( - &mut self, - username: U, - password: P, - ) -> &mut Self { - self.credentials = Some((username.into(), password.into())); - self - } - - /// Security options - pub fn credentials(&self) -> Option<(String, String)> { - self.credentials.clone() - } - - /// Set request channel capacity - pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { - self.request_channel_capacity = capacity; - self - } - - /// Request channel capacity - pub fn request_channel_capacity(&self) -> usize { - self.request_channel_capacity - } - - /// Enables throttling and sets outoing message rate to the specified 'rate' - pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { - self.pending_throttle = duration; - self - } - - /// Outgoing message rate - pub fn pending_throttle(&self) -> Duration { - self.pending_throttle - } - - /// Set number of concurrent in flight messages - pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - if inflight == 0 { - panic!("zero in flight is not allowed") - } - - self.inflight = inflight; - self - } - - /// Number of concurrent in flight messages - pub fn inflight(&self) -> u16 { - self.inflight - } - - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; - self - } - - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout - } - - /// set manual acknowledgements - pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { - self.manual_acks = manual_acks; - self - } - - /// get manual acknowledgements - pub fn manual_acks(&self) -> bool { - self.manual_acks - } -} - -#[cfg(feature = "url")] -#[derive(Debug, PartialEq, thiserror::Error)] -pub enum OptionError { - #[error("Unsupported URL scheme.")] - Scheme, - - #[error("Missing client ID.")] - ClientId, - - #[error("Invalid keep-alive value.")] - KeepAlive, - - #[error("Invalid clean-session value.")] - CleanSession, - - #[error("Invalid max-incoming-packet-size value.")] - MaxIncomingPacketSize, - - #[error("Invalid max-outgoing-packet-size value.")] - MaxOutgoingPacketSize, - - #[error("Invalid request-channel-capacity value.")] - RequestChannelCapacity, - - #[error("Invalid max-request-batch value.")] - MaxRequestBatch, - - #[error("Invalid pending-throttle value.")] - PendingThrottle, - - #[error("Invalid inflight value.")] - Inflight, - - #[error("Invalid conn-timeout value.")] - ConnTimeout, - - #[error("Unknown option: {0}")] - Unknown(String), -} - -#[cfg(feature = "url")] -impl std::convert::TryFrom for MqttOptions { - type Error = OptionError; - - fn try_from(url: url::Url) -> Result { - use std::collections::HashMap; - - let broker_addr = url.host_str().unwrap_or_default().to_owned(); - - let (transport, default_port) = match url.scheme() { - // Encrypted connections are supported, but require explicit TLS configuration. We fall - // back to the unencrypted transport layer, so that `set_transport` can be used to - // configure the encrypted transport layer with the provided TLS configuration. - "mqtts" | "ssl" => (Transport::Tcp, 8883), - "mqtt" | "tcp" => (Transport::Tcp, 1883), - _ => return Err(OptionError::Scheme), - }; - - let port = url.port().unwrap_or(default_port); - - let mut queries = url.query_pairs().collect::>(); - - let keep_alive = Duration::from_secs( - queries - .remove("keep_alive_secs") - .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) - .transpose()? - .unwrap_or(60), - ); - - let client_id = queries - .remove("client_id") - .ok_or(OptionError::ClientId)? - .into_owned(); - - let clean_session = queries - .remove("clean_session") - .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) - .transpose()? - .unwrap_or(true); - - let credentials = { - match url.username() { - "" => None, - username => Some(( - username.to_owned(), - url.password().unwrap_or_default().to_owned(), - )), - } - }; - - let max_incoming_packet_size = queries - .remove("max_incoming_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxIncomingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let max_outgoing_packet_size = queries - .remove("max_outgoing_packet_size_bytes") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::MaxOutgoingPacketSize) - }) - .transpose()? - .unwrap_or(10 * 1024); - - let request_channel_capacity = queries - .remove("request_channel_capacity_num") - .map(|v| { - v.parse::() - .map_err(|_| OptionError::RequestChannelCapacity) - }) - .transpose()? - .unwrap_or(10); - - let max_request_batch = queries - .remove("max_request_batch_num") - .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) - .transpose()? - .unwrap_or(0); - - let pending_throttle = Duration::from_micros( - queries - .remove("pending_throttle_usecs") - .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) - .transpose()? - .unwrap_or(0), - ); - - let inflight = queries - .remove("inflight_num") - .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) - .transpose()? - .unwrap_or(100); - - let conn_timeout = queries - .remove("conn_timeout_secs") - .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) - .transpose()? - .unwrap_or(5); - - if let Some((opt, _)) = queries.into_iter().next() { - return Err(OptionError::Unknown(opt.into_owned())); - } - - Ok(Self { - broker_addr, - port, - transport, - keep_alive, - clean_session, - client_id, - credentials, - max_incoming_packet_size, - max_outgoing_packet_size, - request_channel_capacity, - max_request_batch, - pending_throttle, - inflight, - last_will: None, - conn_timeout, - manual_acks: false, - }) - } -} - -// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't -// work. -impl Debug for MqttOptions { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("MqttOptions") - .field("broker_addr", &self.broker_addr) - .field("port", &self.port) - .field("keep_alive", &self.keep_alive) - .field("clean_session", &self.clean_session) - .field("client_id", &self.client_id) - .field("credentials", &self.credentials) - .field("max_packet_size", &self.max_incoming_packet_size) - .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) - .field("pending_throttle", &self.pending_throttle) - .field("inflight", &self.inflight) - .field("last_will", &self.last_will) - .field("conn_timeout", &self.conn_timeout) - .field("manual_acks", &self.manual_acks) - .finish() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[should_panic] - fn client_id_startswith_space() { - let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); - } - - #[test] - #[cfg(all(feature = "use-rustls", feature = "websocket"))] - fn no_scheme() { - let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); - - _mqtt_opts.set_transport(crate::v4::Transport::wss(Vec::from("Test CA"), None, None)); - - if let crate::v4::Transport::Wss(TlsConfiguration::Simple { - ca, - client_auth, - alpn, - }) = _mqtt_opts.transport - { - assert_eq!(ca, Vec::from("Test CA")); - assert_eq!(client_auth, None); - assert_eq!(alpn, None); - } else { - panic!("Unexpected transport!"); - } - - assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); - } - - #[test] - #[cfg(feature = "url")] - fn from_url() { - use std::convert::TryInto; - use std::str::FromStr; - - fn opt(s: &str) -> Result { - url::Url::from_str(s).expect("valid url").try_into() - } - fn ok(s: &str) -> MqttOptions { - opt(s).expect("valid options") - } - fn err(s: &str) -> OptionError { - opt(s).expect_err("invalid options") - } - - let v = ok("mqtt://host:42?client_id=foo"); - assert_eq!(v.broker_address(), ("host".to_owned(), 42)); - assert_eq!(v.client_id(), "foo".to_owned()); - - let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); - assert_eq!(v.keep_alive, Duration::from_secs(5)); - - assert_eq!(err("mqtt://host:42"), OptionError::ClientId); - assert_eq!( - err("mqtt://host:42?client_id=foo&foo=bar"), - OptionError::Unknown("foo".to_owned()) - ); - assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); - assert_eq!( - err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), - OptionError::KeepAlive - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&clean_session=foo"), - OptionError::CleanSession - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), - OptionError::MaxIncomingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), - OptionError::MaxOutgoingPacketSize - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), - OptionError::RequestChannelCapacity - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), - OptionError::MaxRequestBatch - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), - OptionError::PendingThrottle - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&inflight_num=foo"), - OptionError::Inflight - ); - assert_eq!( - err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), - OptionError::ConnTimeout - ); - } - - #[test] - #[should_panic] - fn no_client_id() { - let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); - } -} diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs index 811b6889d..f20efe87e 100644 --- a/rumqttc/src/v5/client/mod.rs +++ b/rumqttc/src/v5/client/mod.rs @@ -33,7 +33,6 @@ fn get_ack_req(qos: QoS, pkid: u16) -> Option { Some(ack) } - /// MQTT connection. Maintains all the necessary state pub struct Connection { pub eventloop: EventLoop, diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 74edd7a6c..90deb7c11 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,9 +1,9 @@ -use crate::v5::{ - framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, - StateError, Transport, -}; #[cfg(feature = "use-rustls")] use crate::v5::tls; +use crate::v5::{ + framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, StateError, + Transport, +}; #[cfg(feature = "websocket")] use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index 92f18f9bd..dc788c0e4 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -9,7 +9,7 @@ use tokio::{task, time}; use async_channel::{bounded, Receiver, Sender}; use bytes::BytesMut; -use rumqttc::v4::{Event, Incoming, Outgoing, Packet}; +use rumqttc::{Event, Incoming, Outgoing, Packet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub struct Broker { diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 68337a452..3fc3e48c9 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -5,7 +5,7 @@ use tokio::{task, time}; mod broker; use broker::*; -use rumqttc::v4::*; +use rumqttc::*; async fn start_requests(count: u8, qos: QoS, delay: u64, requests_tx: Sender) { for i in 1..=count { From 7792a5ef6e617147359872ea4b0b147fcc50f2fd Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Fri, 1 Apr 2022 22:02:55 +0530 Subject: [PATCH 38/38] rumqttc: v5: move shared buffer to separate struct Signed-off-by: Abhik Jain --- rumqttc/src/v5/client/asyncclient.rs | 208 ++++++---- rumqttc/src/v5/client/syncclient.rs | 74 +++- rumqttc/src/v5/eventloop.rs | 32 +- rumqttc/src/v5/mod.rs | 7 +- rumqttc/src/v5/outgoing_buf.rs | 34 ++ rumqttc/src/v5/state.rs | 596 ++++++++++++++------------- 6 files changed, 549 insertions(+), 402 deletions(-) create mode 100644 rumqttc/src/v5/outgoing_buf.rs diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs index 798439e79..fb39f8d6a 100644 --- a/rumqttc/src/v5/client/asyncclient.rs +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -1,13 +1,11 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; +use std::sync::{Arc, Mutex}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; use crate::v5::{ client::get_ack_req, + outgoing_buf::OutgoingBuf, packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, ClientError, EventLoop, MqttOptions, QoS, Request, }; @@ -16,10 +14,7 @@ use crate::v5::{ /// This is cloneable and can be used to asynchronously Publish, Subscribe. #[derive(Debug)] pub struct AsyncClient { - pub(crate) outgoing_buf: Arc>>, - pub(crate) outgoing_buf_capacity: usize, - pub(crate) pkid_counter: u16, - pub(crate) max_inflight: u16, + pub(crate) outgoing_buf: Arc>, pub(crate) request_tx: Sender<()>, } @@ -27,15 +22,11 @@ impl AsyncClient { /// Create a new `AsyncClient` pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); - let outgoing_buf = eventloop.request_buf().clone(); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); let request_tx = eventloop.handle(); - let max_inflight = eventloop.state.max_inflight; let client = AsyncClient { outgoing_buf, - outgoing_buf_capacity: cap, - pkid_counter: 0, - max_inflight, request_tx, }; @@ -57,13 +48,18 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = if qos != QoS::AtMostOnce { - self.increment_pkid() + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid } else { 0 }; - publish.pkid = pkid; - self.push_and_async_notify(Request::Publish(publish)) - .await?; + self.notify_async().await?; Ok(pkid) } @@ -82,19 +78,30 @@ impl AsyncClient { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; let pkid = if qos != QoS::AtMostOnce { - self.increment_pkid() + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid } else { 0 }; - publish.pkid = pkid; - self.push_and_try_notify(Request::Publish(publish))?; + self.try_notify()?; Ok(pkid) } /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.push_and_async_notify(ack).await?; + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.notify_async().await?; } Ok(()) } @@ -102,7 +109,12 @@ impl AsyncClient { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.push_and_try_notify(ack)?; + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.try_notify()?; } Ok(()) } @@ -121,13 +133,18 @@ impl AsyncClient { let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; let pkid = if qos != QoS::AtMostOnce { - self.increment_pkid() + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid } else { 0 }; - publish.pkid = pkid; - self.push_and_async_notify(Request::Publish(publish)) - .await?; + self.notify_async().await?; Ok(pkid) } @@ -138,10 +155,17 @@ impl AsyncClient { qos: QoS, ) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_async_notify(Request::Subscribe(subscribe)) - .await?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; Ok(pkid) } @@ -152,9 +176,17 @@ impl AsyncClient { qos: QoS, ) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_try_notify(Request::Subscribe(subscribe))?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; Ok(pkid) } @@ -164,10 +196,17 @@ impl AsyncClient { T: IntoIterator, { let mut subscribe = Subscribe::new_many(topics); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_async_notify(Request::Subscribe(subscribe)) - .await?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; Ok(pkid) } @@ -177,86 +216,97 @@ impl AsyncClient { T: IntoIterator, { let mut subscribe = Subscribe::new_many(topics); - let pkid = self.increment_pkid(); - subscribe.pkid = pkid; - self.push_and_try_notify(Request::Subscribe(subscribe))?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop pub async fn unsubscribe>(&mut self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); - let pkid = self.increment_pkid(); - unsubscribe.pkid = pkid; - self.push_and_async_notify(Request::Unsubscribe(unsubscribe)) - .await?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.notify_async().await?; Ok(pkid) } /// Sends a MQTT Unsubscribe to the eventloop pub fn try_unsubscribe>(&mut self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); - let pkid = self.increment_pkid(); - unsubscribe.pkid = pkid; - self.push_and_try_notify(Request::Unsubscribe(unsubscribe))?; + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.try_notify()?; Ok(pkid) } /// Sends a MQTT disconnect to the eventloop + #[inline] pub async fn disconnect(&mut self) -> Result<(), ClientError> { - self.push_and_async_notify(Request::Disconnect).await + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.notify_async().await } /// Sends a MQTT disconnect to the eventloop + #[inline] pub fn try_disconnect(&mut self) -> Result<(), ClientError> { - self.push_and_try_notify(Request::Disconnect) + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.try_notify() } - async fn push_and_async_notify(&self, request: Request) -> Result<(), ClientError> { - { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.outgoing_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); - } + #[inline] + async fn notify_async(&self) -> Result<(), ClientError> { if let Err(SendError(_)) = self.request_tx.send_async(()).await { return Err(ClientError::EventloopClosed); }; Ok(()) } - pub(crate) fn push_and_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.outgoing_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); + #[inline] + pub(crate) fn notify(&self) -> Result<(), ClientError> { if let Err(SendError(_)) = self.request_tx.send(()) { return Err(ClientError::EventloopClosed); }; Ok(()) } - fn push_and_try_notify(&self, request: Request) -> Result<(), ClientError> { - let mut request_buf = self.outgoing_buf.lock().unwrap(); - if request_buf.len() == self.outgoing_buf_capacity { - return Err(ClientError::RequestsFull); - } - request_buf.push_back(request); + #[inline] + fn try_notify(&self) -> Result<(), ClientError> { if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { return Err(ClientError::EventloopClosed); } Ok(()) } - - #[inline] - pub(crate) fn increment_pkid(&mut self) -> u16 { - self.pkid_counter = if self.pkid_counter == self.max_inflight { - 1 - } else { - self.pkid_counter + 1 - }; - self.pkid_counter - } } diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs index ae94a1119..425ad4e03 100644 --- a/rumqttc/src/v5/client/syncclient.rs +++ b/rumqttc/src/v5/client/syncclient.rs @@ -42,8 +42,19 @@ impl Client { { let mut publish = Publish::new(topic, qos, payload); publish.retain = retain; - let pkid = publish.pkid; - self.client.push_and_notify(Request::Publish(publish))?; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.client.notify()?; Ok(pkid) } @@ -64,7 +75,12 @@ impl Client { /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { - self.client.push_and_notify(ack)?; + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.client.notify()?; } Ok(()) } @@ -77,9 +93,19 @@ impl Client { /// Sends a MQTT Subscribe to the eventloop pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result { let mut subscribe = Subscribe::new(topic.into(), qos); - let pkid = self.client.increment_pkid(); - subscribe.pkid = pkid; - self.client.push_and_notify(Request::Subscribe(subscribe))?; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + } else { + 0 + }; + self.client.notify()?; Ok(pkid) } @@ -98,9 +124,17 @@ impl Client { T: IntoIterator, { let mut subscribe = Subscribe::new_many(topics); - let pkid = self.client.increment_pkid(); - subscribe.pkid = pkid; - self.client.push_and_notify(Request::Subscribe(subscribe))?; + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.client.notify()?; Ok(pkid) } @@ -114,10 +148,17 @@ impl Client { /// Sends a MQTT Unsubscribe to the eventloop pub fn unsubscribe>(&mut self, topic: S) -> Result { let mut unsubscribe = Unsubscribe::new(topic.into()); - let pkid = self.client.increment_pkid(); - unsubscribe.pkid = pkid; - self.client - .push_and_notify(Request::Unsubscribe(unsubscribe))?; + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.client.notify()?; Ok(pkid) } @@ -128,7 +169,12 @@ impl Client { /// Sends a MQTT disconnect to the eventloop pub fn disconnect(&mut self) -> Result<(), ClientError> { - self.client.push_and_notify(Request::Disconnect) + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.client.notify() } /// Sends a MQTT disconnect to the eventloop diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 90deb7c11..7725c373f 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,8 +1,8 @@ #[cfg(feature = "use-rustls")] use crate::v5::tls; use crate::v5::{ - framed::Network, packet::*, Incoming, MqttOptions, MqttState, Packet, Request, StateError, - Transport, + framed::Network, outgoing_buf::OutgoingBuf, packet::*, Incoming, MqttOptions, MqttState, + Packet, Request, StateError, Transport, }; #[cfg(feature = "websocket")] @@ -55,8 +55,8 @@ pub struct EventLoop { pub options: MqttOptions, /// Current state of the connection pub state: MqttState, - incoming_buf: Arc>>, - incoming_buf_cache: VecDeque, + outgoing_buf: Arc>, + outgoing_buf_cache: VecDeque, /// Request stream pub incoming_rx: Receiver<()>, /// Requests handle to send requests @@ -76,17 +76,18 @@ impl EventLoop { /// access and update `options`, `state` and `requests`. pub fn new(options: MqttOptions, cap: usize) -> EventLoop { let (incoming_tx, incoming_rx) = bounded(1); - let request_buf = Arc::new(Mutex::new(VecDeque::with_capacity(cap))); let pending = Vec::new(); let pending = pending.into_iter(); let max_inflight = options.inflight; let manual_acks = options.manual_acks; + let state = MqttState::new(max_inflight, manual_acks, cap); + let outgoing_buf = state.outgoing_buf.clone(); EventLoop { options, - state: MqttState::new(max_inflight, manual_acks, cap), - incoming_buf: request_buf, - incoming_buf_cache: VecDeque::with_capacity(cap), + state, + outgoing_buf, + outgoing_buf_cache: VecDeque::with_capacity(cap), incoming_tx, incoming_rx, pending, @@ -96,18 +97,11 @@ impl EventLoop { } /// Returns a handle to communicate with this eventloop + #[inline] pub fn handle(&self) -> Sender<()> { self.incoming_tx.clone() } - pub fn request_buf(&self) -> &Arc>> { - &self.incoming_buf - } - - pub fn sub_events_buf(&self) -> &Arc>> { - &self.state.incoming_buf - } - fn clean(&mut self) { self.network = None; self.keepalive_timeout = None; @@ -187,11 +181,11 @@ impl EventLoop { o = self.incoming_rx.recv_async(), if !inflight_full && !pending && !collision => match o { Ok(_request_notif) => { // swapping to avoid blocking the mutex - std::mem::swap(&mut self.incoming_buf_cache,&mut *self.incoming_buf.lock().unwrap()); - if self.incoming_buf_cache.is_empty() { + std::mem::swap(&mut self.outgoing_buf_cache, &mut self.outgoing_buf.lock().unwrap().buf); + if self.outgoing_buf_cache.is_empty() { continue; } - for request in self.incoming_buf_cache.drain(..) { + for request in self.outgoing_buf_cache.drain(..) { self.state.handle_outgoing_packet(request)?; } network.flush(&mut self.state.write).await?; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 615d18148..07230ab62 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -10,6 +10,7 @@ mod client; mod eventloop; mod framed; mod notifier; +mod outgoing_buf; #[allow(clippy::all)] mod packet; mod state; @@ -592,17 +593,13 @@ impl Debug for MqttOptions { pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, Notifier), ()> { let mut eventloop = EventLoop::new(options, cap); - let outgoing_buf = eventloop.request_buf().clone(); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); let incoming_buf = eventloop.state.incoming_buf.clone(); let incoming_buf_cache = VecDeque::with_capacity(cap); let request_tx = eventloop.handle(); - let max_inflight = eventloop.state.max_inflight; let client = AsyncClient { outgoing_buf, - outgoing_buf_capacity: cap, - pkid_counter: 0, - max_inflight, request_tx, }; diff --git a/rumqttc/src/v5/outgoing_buf.rs b/rumqttc/src/v5/outgoing_buf.rs new file mode 100644 index 000000000..d036dd0c8 --- /dev/null +++ b/rumqttc/src/v5/outgoing_buf.rs @@ -0,0 +1,34 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use crate::v5::Request; + +#[derive(Debug)] +pub struct OutgoingBuf { + pub(crate) buf: VecDeque, + pub(crate) pkid_counter: u16, + pub(crate) capacity: usize, +} + +impl OutgoingBuf { + #[inline] + pub fn new(cap: usize) -> Arc> { + Arc::new(Mutex::new(Self { + buf: VecDeque::with_capacity(cap), + pkid_counter: 0, + capacity: cap, + })) + } + + #[inline] + pub fn increment_pkid(&mut self) -> u16 { + self.pkid_counter = if self.pkid_counter == self.capacity as u16 { + 1 + } else { + self.pkid_counter + 1 + }; + self.pkid_counter + } +} diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 0b304f666..b6a8ebfd9 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,13 +1,14 @@ -use super::{packet::*, Incoming, Request}; - -use bytes::BytesMut; -use std::collections::VecDeque; use std::{ + collections::VecDeque, io, mem, sync::{Arc, Mutex}, time::Instant, }; +use bytes::BytesMut; + +use crate::v5::{outgoing_buf::OutgoingBuf, packet::*, Incoming, Request}; + /// Errors during state handling #[derive(Debug, thiserror::Error)] pub enum StateError { @@ -61,8 +62,6 @@ pub struct MqttState { last_outgoing: Instant, /// Number of outgoing inflight publishes pub(crate) inflight: u16, - /// Maximum number of allowed inflight - pub(crate) max_inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet pub(crate) outgoing_pub: Vec>, /// Packet ids of released QoS 2 publishes @@ -76,6 +75,7 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, pub(crate) incoming_buf: Arc>>, + pub(crate) outgoing_buf: Arc>, } impl MqttState { @@ -89,7 +89,6 @@ impl MqttState { last_incoming: Instant::now(), last_outgoing: Instant::now(), inflight: 0, - max_inflight, // index 0 is wasted as 0 is not a valid packet id outgoing_pub: vec![None; max_inflight as usize + 1], outgoing_rel: vec![None; max_inflight as usize + 1], @@ -99,6 +98,7 @@ impl MqttState { write: BytesMut::with_capacity(10 * 1024), manual_acks, incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + outgoing_buf: OutgoingBuf::new(max_inflight as usize), } } @@ -137,6 +137,11 @@ impl MqttState { self.inflight } + #[inline] + pub fn cur_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().pkid_counter + } + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { @@ -430,6 +435,11 @@ impl MqttState { Ok(pubrel) } + #[inline] + pub fn increment_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().increment_pkid() + } + ///// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation ///// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. ///// @@ -452,282 +462,298 @@ impl MqttState { #[cfg(test)] mod test { - // use super::{MqttState, StateError}; - // use crate::v5::{packet::*, Incoming, MqttOptions, Request}; - - // fn build_outgoing_publish(qos: QoS) -> Publish { - // let topic = "hello/world".to_owned(); - // let payload = vec![1, 2, 3]; - - // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - // publish.qos = qos; - // publish - // } - - // fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { - // let topic = "hello/world".to_owned(); - // let payload = vec![1, 2, 3]; - - // let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); - // publish.pkid = pkid; - // publish.qos = qos; - // publish - // } - - // fn build_mqttstate() -> MqttState { - // MqttState::new(100, false, 100) - // } - - // #[test] - // fn next_pkid_increments_as_expected() { - // let mqtt = build_mqttstate(); - - // for i in 1..=100 { - // let pkid = mqtt.increment_pkid(); - - // // loops between 0-99. % 100 == 0 implies border - // let expected = i % 100; - // if expected == 0 { - // break; - // } - - // assert_eq!(expected, pkid); - // } - // } - - // #[test] - // fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { - // let mut mqtt = build_mqttstate(); - - // // QoS0 Publish - // let publish = build_outgoing_publish(QoS::AtMostOnce); - - // // QoS 0 publish shouldn't be saved in queue - // mqtt.outgoing_publish(publish).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 0); - // assert_eq!(mqtt.inflight, 0); - - // // QoS1 Publish - // let publish = build_outgoing_publish(QoS::AtLeastOnce); - - // // Packet id should be set and publish should be saved in queue - // mqtt.outgoing_publish(publish.clone()).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 1); - // assert_eq!(mqtt.inflight, 1); - - // // Packet id should be incremented and publish should be saved in queue - // mqtt.outgoing_publish(publish).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 2); - // assert_eq!(mqtt.inflight, 2); - - // // QoS1 Publish - // let publish = build_outgoing_publish(QoS::ExactlyOnce); - - // // Packet id should be set and publish should be saved in queue - // mqtt.outgoing_publish(publish.clone()).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 3); - // assert_eq!(mqtt.inflight, 3); - - // // Packet id should be incremented and publish should be saved in queue - // mqtt.outgoing_publish(publish).unwrap(); - // assert_eq!(mqtt.cur_pkid(), 4); - // assert_eq!(mqtt.inflight, 4); - // } - - // #[test] - // fn incoming_publish_should_be_added_to_queue_correctly() { - // let mut mqtt = build_mqttstate(); - - // // QoS0, 1, 2 Publishes - // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - // mqtt.handle_incoming_publish(&publish1).unwrap(); - // mqtt.handle_incoming_publish(&publish2).unwrap(); - // mqtt.handle_incoming_publish(&publish3).unwrap(); - - // let pkid = mqtt.incoming_pub[3].unwrap(); - - // // only qos2 publish should be add to queue - // assert_eq!(pkid, 3); - // } - - // #[test] - // fn incoming_publish_should_be_acked() { - // let mut mqtt = build_mqttstate(); - - // // QoS0, 1, 2 Publishes - // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - // mqtt.handle_incoming_publish(&publish1).unwrap(); - // mqtt.handle_incoming_publish(&publish2).unwrap(); - // mqtt.handle_incoming_publish(&publish3).unwrap(); - // } - - // #[test] - // fn incoming_publish_should_not_be_acked_with_manual_acks() { - // let mut mqtt = build_mqttstate(); - // mqtt.manual_acks = true; - - // // QoS0, 1, 2 Publishes - // let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - // let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - // let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - - // mqtt.handle_incoming_publish(&publish1).unwrap(); - // mqtt.handle_incoming_publish(&publish2).unwrap(); - // mqtt.handle_incoming_publish(&publish3).unwrap(); - - // let pkid = mqtt.incoming_pub[3].unwrap(); - // assert_eq!(pkid, 3); - - // assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); - // } - - // #[test] - // fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { - // let mut mqtt = build_mqttstate(); - // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - // mqtt.handle_incoming_publish(&publish).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - // _ => panic!("Invalid network request: {:?}", packet), - // } - // } - - // #[test] - // fn incoming_puback_should_remove_correct_publish_from_queue() { - // let mut mqtt = build_mqttstate(); - - // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - // mqtt.outgoing_publish(publish1).unwrap(); - // mqtt.outgoing_publish(publish2).unwrap(); - // assert_eq!(mqtt.inflight, 2); - - // mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); - // assert_eq!(mqtt.inflight, 1); - - // mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); - // assert_eq!(mqtt.inflight, 0); - - // assert!(mqtt.outgoing_pub[1].is_none()); - // assert!(mqtt.outgoing_pub[2].is_none()); - // } - - // #[test] - // fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { - // let mut mqtt = build_mqttstate(); - - // let publish1 = build_outgoing_publish(QoS::AtLeastOnce); - // let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - - // let _publish_out = mqtt.outgoing_publish(publish1); - // let _publish_out = mqtt.outgoing_publish(publish2); - - // mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); - // assert_eq!(mqtt.inflight, 2); - - // // check if the remaining element's pkid is 1 - // let backup = mqtt.outgoing_pub[1].clone(); - // assert_eq!(backup.unwrap().pkid, 1); - - // // check if the qos2 element's release pkid is 2 - // assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); - // } - - // #[test] - // fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { - // let mut mqtt = build_mqttstate(); - - // let publish = build_outgoing_publish(QoS::ExactlyOnce); - // mqtt.outgoing_publish(publish).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::Publish(publish) => assert_eq!(publish.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - - // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - // } - - // #[test] - // fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { - // let mut mqtt = build_mqttstate(); - // let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - - // mqtt.handle_incoming_publish(&publish).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - - // mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - // let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); - // match packet { - // Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), - // packet => panic!("Invalid network request: {:?}", packet), - // } - // } - - // #[test] - // fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { - // let mut mqtt = build_mqttstate(); - // let publish = build_outgoing_publish(QoS::ExactlyOnce); - - // mqtt.outgoing_publish(publish).unwrap(); - // mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - - // mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); - // assert_eq!(mqtt.inflight, 0); - // } - - // #[test] - // fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { - // let mut mqtt = build_mqttstate(); - // let mut opts = MqttOptions::new("test", "localhost", 1883); - // opts.set_keep_alive(std::time::Duration::from_secs(10)); - // mqtt.outgoing_ping().unwrap(); - - // // network activity other than pingresp - // let publish = build_outgoing_publish(QoS::AtLeastOnce); - // mqtt.handle_outgoing_packet(Request::Publish(publish)) - // .unwrap(); - // mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) - // .unwrap(); - - // // should throw error because we didn't get pingresp for previous ping - // match mqtt.outgoing_ping() { - // Ok(_) => panic!("Should throw pingresp await error"), - // Err(StateError::AwaitPingResp) => (), - // Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), - // } - // } - - // #[test] - // fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { - // let mut mqtt = build_mqttstate(); - - // let mut opts = MqttOptions::new("test", "localhost", 1883); - // opts.set_keep_alive(std::time::Duration::from_secs(10)); - - // // should ping - // mqtt.outgoing_ping().unwrap(); - // mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); - - // // should ping - // mqtt.outgoing_ping().unwrap(); - // } + use super::{MqttState, StateError}; + use crate::v5::{packet::*, Incoming, MqttOptions, Request}; + + fn build_outgoing_publish(qos: QoS) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.qos = qos; + publish + } + + fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.pkid = pkid; + publish.qos = qos; + publish + } + + fn build_mqttstate() -> MqttState { + MqttState::new(100, false, 100) + } + + #[test] + fn next_pkid_increments_as_expected() { + let mqtt = build_mqttstate(); + + for i in 1..=100 { + let pkid = mqtt.increment_pkid(); + + // loops between 0-99. % 100 == 0 implies border + let expected = i % 100; + if expected == 0 { + break; + } + + assert_eq!(expected, pkid); + } + } + + #[test] + fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + let mut mqtt = build_mqttstate(); + + // QoS0 Publish + let mut publish = build_outgoing_publish(QoS::AtMostOnce); + publish.pkid = 1; + + // QoS 0 publish shouldn't be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 0); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 2; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 1); + + // Packet id should be incremented and publish should be saved in queue + publish.pkid = 3; + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 2); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 4; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 3); + + publish.pkid = 5; + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 4); + } + + #[test] + fn incoming_publish_should_be_added_to_queue_correctly() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + + // only qos2 publish should be add to queue + assert_eq!(pkid, 3); + } + + #[test] + fn incoming_publish_should_be_acked() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + } + + #[test] + fn incoming_publish_should_not_be_acked_with_manual_acks() { + let mut mqtt = build_mqttstate(); + mqtt.manual_acks = true; + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + assert_eq!(pkid, 3); + + assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); + } + + #[test] + fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + _ => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_puback_should_remove_correct_publish_from_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + mqtt.outgoing_publish(publish1).unwrap(); + mqtt.outgoing_publish(publish2).unwrap(); + assert_eq!(mqtt.inflight, 2); + + mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 0); + + assert!(mqtt.outgoing_pub[1].is_none()); + assert!(mqtt.outgoing_pub[2].is_none()); + } + + #[test] + fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + let _publish_out = mqtt.outgoing_publish(publish1); + let _publish_out = mqtt.outgoing_publish(publish2); + + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 2); + + // check if the remaining element's pkid is 1 + let backup = mqtt.outgoing_pub[1].clone(); + assert_eq!(backup.unwrap().pkid, 1); + + // check if the qos2 element's release pkid is 2 + assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + } + + #[test] + fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + mqtt.outgoing_publish(publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + let mut mqtt = build_mqttstate(); + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + + mqtt.outgoing_publish(publish).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 0); + } + + #[test] + fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + let mut mqtt = build_mqttstate(); + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + mqtt.outgoing_ping().unwrap(); + + // network activity other than pingresp + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 1; + mqtt.handle_outgoing_packet(Request::Publish(publish)) + .unwrap(); + mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + .unwrap(); + + // should throw error because we didn't get pingresp for previous ping + match mqtt.outgoing_ping() { + Ok(_) => panic!("Should throw pingresp await error"), + Err(StateError::AwaitPingResp) => (), + Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + } + } + + #[test] + fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + let mut mqtt = build_mqttstate(); + + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // should ping + mqtt.outgoing_ping().unwrap(); + mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // should ping + mqtt.outgoing_ping().unwrap(); + } }