diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1459755b7..70d41005c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,6 +2,10 @@ on: pull_request: branches: - master + paths: + - '**.rs' + - 'Cargo.*' + - '*/Cargo.*' name: Build and Test diff --git a/Cargo.lock b/Cargo.lock index d2bc02c3e..142b86ca7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,6 +46,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anyhow" version = "1.0.56" @@ -408,6 +417,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "diff" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" + [[package]] name = "difference" version = "2.0.0" @@ -503,6 +518,19 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" +[[package]] +name = "flume" +version = "0.10.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843c03199d0c0ca54bc1ea90ac0d507274c28abcc4f691ae8b4eaa375087c76a" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "pin-project", + "spin 0.9.2", +] + [[package]] name = "fnv" version = "1.0.7" @@ -653,8 +681,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9be70c98951c83b8d2f8f60d7065fa6d5146873094452a1008da8c2f1e4205ad" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi 0.10.2+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1066,6 +1096,15 @@ dependencies = [ "twoway", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.10" @@ -1419,12 +1458,24 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f81e1644e1b54f5a68959a29aa86cde704219254669da328ecfdf6a1f09d427" dependencies = [ - "ansi_term", + "ansi_term 0.11.0", "ctor", "difference", "output_vt100", ] +[[package]] +name = "pretty_assertions" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c038cb5319b9c704bf9c227c261d275bfec0ad438118a2787ce47944fb228b" +dependencies = [ + "ansi_term 0.12.1", + "ctor", + "diff", + "output_vt100", +] + [[package]] name = "pretty_env_logger" version = "0.4.0" @@ -1636,7 +1687,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", @@ -1652,12 +1703,13 @@ dependencies = [ "color-backtrace", "crossbeam-channel", "envy", + "flume", "http", "jsonwebtoken", "log", "matches", "pollster", - "pretty_assertions", + "pretty_assertions 1.2.0", "pretty_env_logger", "rustls", "rustls-native-certs", @@ -1682,7 +1734,7 @@ dependencies = [ "jemallocator", "log", "pprof 0.4.5", - "pretty_assertions", + "pretty_assertions 0.6.1", "pretty_env_logger", "rustls-pemfile 0.3.0", "segments", @@ -1962,6 +2014,15 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "511254be0c5bcf062b019a6c89c01a664aa359ded62f78aa72c6fc137c0590e5" +dependencies = [ + "lock_api", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/benchmarks/clients/mesh.rs b/benchmarks/clients/mesh.rs index 8f6d04eec..96f38de50 100644 --- a/benchmarks/clients/mesh.rs +++ b/benchmarks/clients/mesh.rs @@ -3,7 +3,7 @@ use tokio::task; use std::thread; use std::path::PathBuf; use bytes::Bytes; -use rumqttlog::router::{Data}; +use rumqttlog::router::Data; mod common; @@ -74,5 +74,3 @@ async fn read(tx: Sender<(usize, RouterInMessage)>) { println!("Id = {}, Total size = {}", id, total_size); } - - diff --git a/benchmarks/clients/rumqttasync.rs b/benchmarks/clients/rumqttasync.rs index 3bfb49ee4..d16cf466b 100644 --- a/benchmarks/clients/rumqttasync.rs +++ b/benchmarks/clients/rumqttasync.rs @@ -1,4 +1,5 @@ -use rumqttc::*; +use rumqttc::{AsyncClient, Event, Incoming, MqttOptions, QoS}; + use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttasyncqos0.rs b/benchmarks/clients/rumqttasyncqos0.rs index a2a668b0b..d00dd300a 100644 --- a/benchmarks/clients/rumqttasyncqos0.rs +++ b/benchmarks/clients/rumqttasyncqos0.rs @@ -1,4 +1,5 @@ -use rumqttc::*; +use rumqttc::{AsyncClient, Event, Incoming, MqttOptions, QoS}; + use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttsync.rs b/benchmarks/clients/rumqttsync.rs index da85194dd..20c3e5c20 100644 --- a/benchmarks/clients/rumqttsync.rs +++ b/benchmarks/clients/rumqttsync.rs @@ -1,4 +1,4 @@ -use rumqttc::{self, Client, Event, Incoming, MqttOptions, QoS}; +use rumqttc::{Client, Event, Incoming, MqttOptions, QoS}; use std::error::Error; use std::thread; use std::time::{Duration, Instant}; diff --git a/benchmarks/simplerouter/src/protocol/v5.rs b/benchmarks/simplerouter/src/protocol/v5.rs index c63fe816e..ab58f8ebd 100644 --- a/benchmarks/simplerouter/src/protocol/v5.rs +++ b/benchmarks/simplerouter/src/protocol/v5.rs @@ -135,16 +135,16 @@ pub(crate) mod connect { len } - fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + fn read(connect_flags: u8, bytes: &mut Bytes) -> Result, Error> { let last_will = match connect_flags & 0b100 { 0 if (connect_flags & 0b0011_1000) != 0 => { return Err(Error::IncorrectPacketFormat); } 0 => None, _ => { - let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = read_mqtt_bytes(bytes)?; let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); - let will_message = read_mqtt_bytes(&mut bytes)?; + let will_message = read_mqtt_bytes(bytes)?; let will_qos = qos((connect_flags & 0b11000) >> 3)?; Some(LastWill { topic: will_topic, @@ -173,11 +173,11 @@ pub(crate) mod connect { } } - fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + fn read(connect_flags: u8, bytes: &mut Bytes) -> Result, Error> { let username = match connect_flags & 0b1000_0000 { 0 => String::new(), _ => { - let username = read_mqtt_bytes(&mut bytes)?; + let username = read_mqtt_bytes(bytes)?; std::str::from_utf8(&username)?.to_owned() } }; @@ -185,7 +185,7 @@ pub(crate) mod connect { let password = match connect_flags & 0b0100_0000 { 0 => String::new(), _ => { - let password = read_mqtt_bytes(&mut bytes)?; + let password = read_mqtt_bytes(bytes)?; std::str::from_utf8(&password)?.to_owned() } }; @@ -247,7 +247,7 @@ pub(crate) mod connect { } } - fn read(mut bytes: &mut Bytes) -> Result, Error> { + fn read(bytes: &mut Bytes) -> Result, Error> { let mut session_expiry_interval = None; let mut receive_maximum = None; let mut max_packet_size = None; @@ -267,49 +267,49 @@ pub(crate) mod connect { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::SessionExpiryInterval => { - session_expiry_interval = Some(read_u32(&mut bytes)?); + session_expiry_interval = Some(read_u32(bytes)?); cursor += 4; } PropertyType::ReceiveMaximum => { - receive_maximum = Some(read_u16(&mut bytes)?); + receive_maximum = Some(read_u16(bytes)?); cursor += 2; } PropertyType::MaximumPacketSize => { - max_packet_size = Some(read_u32(&mut bytes)?); + max_packet_size = Some(read_u32(bytes)?); cursor += 4; } PropertyType::TopicAliasMaximum => { - topic_alias_max = Some(read_u16(&mut bytes)?); + topic_alias_max = Some(read_u16(bytes)?); cursor += 2; } PropertyType::RequestResponseInformation => { - request_response_info = Some(read_u8(&mut bytes)?); + request_response_info = Some(read_u8(bytes)?); cursor += 1; } PropertyType::RequestProblemInformation => { - request_problem_info = Some(read_u8(&mut bytes)?); + request_problem_info = Some(read_u8(bytes)?); cursor += 1; } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); } PropertyType::AuthenticationMethod => { - let method = read_mqtt_bytes(&mut bytes)?; + let method = read_mqtt_bytes(bytes)?; let method = std::str::from_utf8(&method)?.to_owned(); cursor += 2 + method.len(); authentication_method = Some(method); } PropertyType::AuthenticationData => { - let data = read_mqtt_bytes(&mut bytes)?; + let data = read_mqtt_bytes(bytes)?; cursor += 2 + data.len(); authentication_data = Some(data); } @@ -541,23 +541,23 @@ pub(crate) mod connack { pub fn len(&self) -> usize { let mut len = 0; - if let Some(_) = &self.session_expiry_interval { + if self.session_expiry_interval.is_some() { len += 1 + 4; } - if let Some(_) = &self.receive_max { + if self.receive_max.is_some() { len += 1 + 2; } - if let Some(_) = &self.max_qos { + if self.max_qos.is_some() { len += 1 + 1; } - if let Some(_) = &self.retain_available { + if self.retain_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.max_packet_size { + if self.max_packet_size.is_some() { len += 1 + 4; } @@ -565,7 +565,7 @@ pub(crate) mod connack { len += 1 + 2 + id.len(); } - if let Some(_) = &self.topic_alias_max { + if self.topic_alias_max.is_some() { len += 1 + 2; } @@ -577,19 +577,19 @@ pub(crate) mod connack { len += 1 + 2 + key.len() + 2 + value.len(); } - if let Some(_) = &self.wildcard_subscription_available { + if self.wildcard_subscription_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.subscription_identifiers_available { + if self.subscription_identifiers_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.shared_subscription_available { + if self.shared_subscription_available.is_some() { len += 1 + 1; } - if let Some(_) = &self.server_keep_alive { + if self.server_keep_alive.is_some() { len += 1 + 2; } @@ -612,7 +612,7 @@ pub(crate) mod connack { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut session_expiry_interval = None; let mut receive_max = None; let mut max_qos = None; @@ -640,90 +640,90 @@ pub(crate) mod connack { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::SessionExpiryInterval => { - session_expiry_interval = Some(read_u32(&mut bytes)?); + session_expiry_interval = Some(read_u32(bytes)?); cursor += 4; } PropertyType::ReceiveMaximum => { - receive_max = Some(read_u16(&mut bytes)?); + receive_max = Some(read_u16(bytes)?); cursor += 2; } PropertyType::MaximumQos => { - max_qos = Some(read_u8(&mut bytes)?); + max_qos = Some(read_u8(bytes)?); cursor += 1; } PropertyType::RetainAvailable => { - retain_available = Some(read_u8(&mut bytes)?); + retain_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::AssignedClientIdentifier => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let id = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + id.len(); assigned_client_identifier = Some(id); } PropertyType::MaximumPacketSize => { - max_packet_size = Some(read_u32(&mut bytes)?); + max_packet_size = Some(read_u32(bytes)?); cursor += 4; } PropertyType::TopicAliasMaximum => { - topic_alias_max = Some(read_u16(&mut bytes)?); + topic_alias_max = Some(read_u16(bytes)?); cursor += 2; } PropertyType::ReasonString => { - let reason = read_mqtt_bytes(&mut bytes)?; + let reason = read_mqtt_bytes(bytes)?; let reason = std::str::from_utf8(&reason)?.to_owned(); cursor += 2 + reason.len(); reason_string = Some(reason); } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); } PropertyType::WildcardSubscriptionAvailable => { - wildcard_subscription_available = Some(read_u8(&mut bytes)?); + wildcard_subscription_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::SubscriptionIdentifierAvailable => { - subscription_identifiers_available = Some(read_u8(&mut bytes)?); + subscription_identifiers_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::SharedSubscriptionAvailable => { - shared_subscription_available = Some(read_u8(&mut bytes)?); + shared_subscription_available = Some(read_u8(bytes)?); cursor += 1; } PropertyType::ServerKeepAlive => { - server_keep_alive = Some(read_u16(&mut bytes)?); + server_keep_alive = Some(read_u16(bytes)?); cursor += 2; } PropertyType::ResponseInformation => { - let info = read_mqtt_bytes(&mut bytes)?; + let info = read_mqtt_bytes(bytes)?; let info = std::str::from_utf8(&info)?.to_owned(); cursor += 2 + info.len(); response_information = Some(info); } PropertyType::ServerReference => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let reference = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + reference.len(); server_reference = Some(reference); } PropertyType::AuthenticationMethod => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let method = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + method.len(); authentication_method = Some(method); } PropertyType::AuthenticationData => { - let data = read_mqtt_bytes(&mut bytes)?; + let data = read_mqtt_bytes(bytes)?; cursor += 2 + data.len(); authentication_data = Some(data); } @@ -926,7 +926,7 @@ pub(crate) mod publish { // FIXME: Remove indexes and use get method let stream = &self.raw[self.fixed_header.fixed_header_len..]; - let topic_len = view_u16(&stream)? as usize; + let topic_len = view_u16(stream)? as usize; let stream = &stream[2..]; let topic = view_str(stream, topic_len)?; @@ -935,8 +935,7 @@ pub(crate) mod publish { 0 => 0, 1 => { let stream = &stream[topic_len..]; - let pkid = view_u16(stream)?; - pkid + view_u16(stream)? } v => return Err(Error::InvalidQoS(v)), }; @@ -951,7 +950,7 @@ pub(crate) mod publish { pub fn view_topic(&self) -> Result<&str, Error> { // FIXME: Remove indexes let stream = &self.raw[self.fixed_header.fixed_header_len..]; - let topic_len = view_u16(&stream)? as usize; + let topic_len = view_u16(stream)? as usize; let stream = &stream[2..]; let topic = view_str(stream, topic_len)?; @@ -1030,7 +1029,7 @@ pub(crate) mod publish { buffer.put_u16(pkid); } - buffer.extend_from_slice(&payload); + buffer.extend_from_slice(payload); // TODO: Returned length is wrong in other packets. Fix it Ok(1 + count + len) @@ -1164,7 +1163,7 @@ pub(crate) mod puback { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut reason_string = None; let mut user_properties = Vec::new(); @@ -1177,20 +1176,20 @@ pub(crate) mod puback { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::ReasonString => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let reason = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + reason.len(); reason_string = Some(reason); } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); @@ -1264,8 +1263,7 @@ pub(crate) mod subscribe { retain_forward_rule: RetainForwardRule::OnEverySubscribe, }; - let mut filters = Vec::new(); - filters.push(filter); + let filters = vec![filter]; Subscribe { pkid: 0, filters, @@ -1318,10 +1316,10 @@ pub(crate) mod subscribe { let requested_qos = options & 0b0000_0011; let nolocal = options >> 2 & 0b0000_0001; - let nolocal = if nolocal == 0 { false } else { true }; + let nolocal = !(nolocal == 0); let preserve_retain = options >> 3 & 0b0000_0001; - let preserve_retain = if preserve_retain == 0 { false } else { true }; + let preserve_retain = !(preserve_retain == 0); let retain_forward_rule = (options >> 4) & 0b0000_0011; let retain_forward_rule = match retain_forward_rule { @@ -1461,7 +1459,7 @@ pub(crate) mod subscribe { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut id = None; let mut user_properties = Vec::new(); @@ -1475,7 +1473,7 @@ pub(crate) mod subscribe { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { @@ -1487,9 +1485,9 @@ pub(crate) mod subscribe { id = Some(sub_id) } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); @@ -1679,7 +1677,7 @@ pub(crate) mod suback { len } - pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + pub fn extract(bytes: &mut Bytes) -> Result, Error> { let mut reason_string = None; let mut user_properties = Vec::new(); @@ -1692,20 +1690,20 @@ pub(crate) mod suback { let mut cursor = 0; // read until cursor reaches property length. properties_len = 0 will skip this loop while cursor < properties_len { - let prop = read_u8(&mut bytes)?; + let prop = read_u8(bytes)?; cursor += 1; match property(prop)? { PropertyType::ReasonString => { - let bytes = read_mqtt_bytes(&mut bytes)?; + let bytes = read_mqtt_bytes(bytes)?; let reason = std::str::from_utf8(&bytes)?.to_owned(); cursor += 2 + reason.len(); reason_string = Some(reason); } PropertyType::UserProperty => { - let key = read_mqtt_bytes(&mut bytes)?; + let key = read_mqtt_bytes(bytes)?; let key = std::str::from_utf8(&key)?.to_owned(); - let value = read_mqtt_bytes(&mut bytes)?; + let value = read_mqtt_bytes(bytes)?; let value = std::str::from_utf8(&value)?.to_owned(); cursor += 2 + key.len() + 2 + value.len(); user_properties.push((key, value)); diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index f8fe525b5..aa157ab75 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -31,6 +31,7 @@ log = "0.4" thiserror = "1.0.21" http = "^0.2" url = { version = "2.2", default-features = false, optional = true } +flume = "0.10.10" [dev-dependencies] pretty_env_logger = "0.4" @@ -43,4 +44,4 @@ tokio = { version = "1.0", features = ["full", "macros"] } matches = "0.1.8" rustls = "0.20.2" rustls-native-certs = "0.6.1" -pretty_assertions = "0.6.1" +pretty_assertions = "1.1.0" diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs new file mode 100644 index 000000000..079049e01 --- /dev/null +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -0,0 +1,81 @@ +#![allow(dead_code, unused_imports)] +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +fn create_conn() -> (AsyncClient, EventLoop) { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_manual_acks(true) + .set_clean_session(false); + + AsyncClient::new(mqttoptions, 10) +} + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + todo!("fix this example with new way of spawning clients") + // pretty_env_logger::init(); + + // // create mqtt connection with clean_session = false and manual_acks = true + // let (client, mut eventloop) = create_conn(); + + // // subscribe example topic + // client + // .subscribe("hello/world", QoS::AtLeastOnce) + // .await + // .unwrap(); + + // task::spawn(async move { + // // send some messages to example topic and disconnect + // requests(client.clone()).await; + // client.disconnect().await.unwrap() + // }); + + // loop { + // // get subscribed messages without acking + // let event = eventloop.poll().await; + // println!("{:?}", event); + // if let Err(_err) = event { + // // break loop on disconnection + // break; + // } + // } + + // // create new broker connection + // let (_client, mut eventloop) = create_conn(); + + // loop { + // // previously published messages should be republished after reconnection. + // let event = eventloop.poll().await; + // println!("{:?}", event); + + // todo!("fix the commented out code below") + + // // match event { + // // Ok(Event::Incoming(Incoming::Publish(publish))) => { + // // // this time we will ack incoming publishes. + // // // Its important not to block eventloop as this can cause deadlock. + // // let c = client.clone(); + // // tokio::spawn(async move { + // // c.ack(&publish).await.unwrap(); + // // }); + // // } + // // _ => {} + // // } + // } +} + +async fn requests(client: &AsyncClient) { + for i in 1..=10 { + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } +} diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs new file mode 100644 index 000000000..b0b7a6698 --- /dev/null +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -0,0 +1,43 @@ +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + time::sleep(Duration::from_secs(3)).await; + }); + + loop { + let event = eventloop.poll().await; + println!("{:?}", event.unwrap()); + } +} + +async fn requests(client: AsyncClient) { + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap(); + + for i in 1..=10 { + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } + + time::sleep(Duration::from_secs(120)).await; +} diff --git a/rumqttc/examples/syncpubsub_v5.rs b/rumqttc/examples/syncpubsub_v5.rs new file mode 100644 index 000000000..12957aaa9 --- /dev/null +++ b/rumqttc/examples/syncpubsub_v5.rs @@ -0,0 +1,35 @@ +use rumqttc::v5::{Client, LastWill, MqttOptions, QoS}; +use std::thread; +use std::time::Duration; + +fn main() { + pretty_env_logger::init(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_last_will(will); + + let (client, mut connection) = Client::new(mqttoptions, 10); + thread::spawn(move || publish(client)); + + for (i, notification) in connection.iter().enumerate() { + println!("{}. Notification = {:?}", i, notification); + } + + println!("Done with the stream!!"); +} + +fn publish(client: Client) { + client.subscribe("hello/+/world", QoS::AtMostOnce).unwrap(); + for i in 0..10 { + let payload = vec![1; i as usize]; + let topic = format!("hello/{}/world", i); + let qos = QoS::AtLeastOnce; + + client.publish(topic, qos, true, payload).unwrap(); + } + + thread::sleep(Duration::from_secs(1)); +} diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index ab9f7e2c1..d1cf4af68 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -111,6 +111,7 @@ pub mod mqttbytes; mod state; #[cfg(feature = "use-rustls")] mod tls; +pub mod v5; pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; @@ -434,9 +435,7 @@ impl MqttOptions { /// Set number of seconds after which client should ping the broker /// if there is no other data exchange pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { - if duration.as_secs() < 5 { - panic!("Keep alives should be >= 5 secs"); - } + assert!(duration.as_secs() >= 5, "Keep alives should be >= 5 secs"); self.keep_alive = duration; self @@ -519,9 +518,7 @@ impl MqttOptions { /// Set number of concurrent in flight messages pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { - if inflight == 0 { - panic!("zero in flight is not allowed") - } + assert!(inflight != 0, "zero in flight is not allowed"); self.inflight = inflight; self diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs new file mode 100644 index 000000000..6c6e498bd --- /dev/null +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -0,0 +1,308 @@ +use std::sync::{Arc, Mutex}; + +use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{ + client::get_ack_req, + outgoing_buf::OutgoingBuf, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + ClientError, EventLoop, MqttOptions, QoS, Request, +}; + +/// `AsyncClient` to communicate with MQTT `Eventloop` +/// This is cloneable and can be used to asynchronously Publish, Subscribe. +#[derive(Clone, Debug)] +pub struct AsyncClient { + pub(crate) outgoing_buf: Arc>, + pub(crate) request_tx: Sender<()>, +} + +impl AsyncClient { + /// Create a new `AsyncClient` + pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { + let eventloop = EventLoop::new(options, cap); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); + let request_tx = eventloop.handle(); + + let client = AsyncClient { + outgoing_buf, + request_tx, + }; + + (client, eventloop) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + } + self.notify_async().await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.try_notify()?; + } + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: Bytes, + ) -> Result + where + S: Into, + { + let mut publish = Publish::from_bytes(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&self, topics: T) -> Result + where + T: IntoIterator, + { + let mut subscribe = Subscribe::new_many(topics); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&self, topics: T) -> Result + where + T: IntoIterator, + { + let mut subscribe = Subscribe::new_many(topics); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT disconnect to the eventloop + #[inline] + pub async fn disconnect(&self) -> Result<(), ClientError> { + { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + } + self.notify_async().await + } + + /// Sends a MQTT disconnect to the eventloop + #[inline] + pub fn try_disconnect(&self) -> Result<(), ClientError> { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.try_notify() + } + + #[inline] + async fn notify_async(&self) -> Result<(), ClientError> { + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + #[inline] + pub(crate) fn notify(&self) -> Result<(), ClientError> { + if let Err(SendError(_)) = self.request_tx.send(()) { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + #[inline] + fn try_notify(&self) -> Result<(), ClientError> { + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } +} diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs new file mode 100644 index 000000000..f20efe87e --- /dev/null +++ b/rumqttc/src/v5/client/mod.rs @@ -0,0 +1,97 @@ +//! This module offers a high level synchronous and asynchronous abstraction to +//! async eventloop. +use crate::v5::{packet::*, ConnectionError, EventLoop, Request}; + +use flume::SendError; +use std::mem; +use tokio::runtime::{self, Runtime}; + +mod asyncclient; +pub use asyncclient::AsyncClient; +mod syncclient; +pub use syncclient::Client; + +/// Client Error +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("Failed to send cancel request to eventloop")] + Cancel(SendError<()>), + #[error("Failed to send mqtt request to eventloop, the evenloop has been closed")] + EventloopClosed, + #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] + RequestsFull, + #[error("Serialization error")] + Mqtt5(Error), +} + +fn get_ack_req(qos: QoS, pkid: u16) -> Option { + let ack = match qos { + QoS::AtMostOnce => return None, + QoS::AtLeastOnce => Request::PubAck(PubAck::new(pkid)), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(pkid)), + }; + Some(ack) +} + +/// MQTT connection. Maintains all the necessary state +pub struct Connection { + pub eventloop: EventLoop, + runtime: Option, +} + +impl Connection { + fn new(eventloop: EventLoop, runtime: Runtime) -> Connection { + Connection { + eventloop, + runtime: Some(runtime), + } + } + + /// Returns an iterator over this connection. Iterating over this is all that's + /// necessary to make connection progress and maintain a robust connection. + /// Just continuing to loop will reconnect + /// **NOTE** Don't block this while iterating + #[must_use = "Connection should be iterated over a loop to make progress"] + pub fn iter(&mut self) -> Iter { + let runtime = self.runtime.take().unwrap(); + Iter { + connection: self, + runtime, + } + } +} + +/// Iterator which polls the eventloop for connection progress +pub struct Iter<'a> { + connection: &'a mut Connection, + runtime: runtime::Runtime, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result<(), ConnectionError>; + + fn next(&mut self) -> Option { + let f = self.connection.eventloop.poll(); + match self.runtime.block_on(f) { + Ok(_) => Some(Ok(())), + // closing of request channel should stop the iterator + Err(ConnectionError::RequestsDone) => { + trace!("Done with requests"); + None + } + Err(ConnectionError::Cancel) => { + trace!("Cancellation request received"); + None + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a> Drop for Iter<'a> { + fn drop(&mut self) { + // TODO: Don't create new runtime in drop + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + self.connection.runtime = Some(mem::replace(&mut self.runtime, runtime)); + } +} diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs new file mode 100644 index 000000000..420b452aa --- /dev/null +++ b/rumqttc/src/v5/client/syncclient.rs @@ -0,0 +1,180 @@ +use tokio::runtime; + +use crate::v5::{ + client::get_ack_req, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + AsyncClient, ClientError, Connection, MqttOptions, QoS, Request, +}; + +/// `Client` to communicate with MQTT eventloop `Connection`. +/// +/// Client is cloneable and can be used to synchronously Publish, Subscribe. +/// Asynchronous channel handle can also be extracted if necessary +pub struct Client { + client: AsyncClient, +} + +impl Client { + /// Create a new `Client` + pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { + let (client, eventloop) = AsyncClient::new(options, cap); + let client = Client { client }; + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let connection = Connection::new(eventloop, runtime); + (client, connection) + } + + /// Sends a MQTT Publish to the eventloop + pub fn publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.client.notify()?; + Ok(pkid) + } + + pub fn try_publish( + &self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + self.client.try_publish(topic, qos, retain, payload) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.client.notify()?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + self.client.try_ack(publish) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn subscribe>(&self, topic: S, qos: QoS) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + } else { + 0 + }; + self.client.notify()?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result { + self.client.try_subscribe(topic, qos) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn subscribe_many(&self, topics: T) -> Result + where + T: IntoIterator, + { + let mut subscribe = Subscribe::new_many(topics); + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.client.notify()?; + Ok(pkid) + } + + pub fn try_subscribe_many(&self, topics: T) -> Result + where + T: IntoIterator, + { + self.client.try_subscribe_many(topics) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn unsubscribe>(&self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.client.notify()?; + Ok(pkid) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&self, topic: S) -> Result { + self.client.try_unsubscribe(topic) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn disconnect(&self) -> Result<(), ClientError> { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.client.notify() + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&self) -> Result<(), ClientError> { + self.client.try_disconnect() + } +} diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs new file mode 100644 index 000000000..7725c373f --- /dev/null +++ b/rumqttc/src/v5/eventloop.rs @@ -0,0 +1,361 @@ +#[cfg(feature = "use-rustls")] +use crate::v5::tls; +use crate::v5::{ + framed::Network, outgoing_buf::OutgoingBuf, packet::*, Incoming, MqttOptions, MqttState, + Packet, Request, StateError, Transport, +}; + +#[cfg(feature = "websocket")] +use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use flume::{bounded, Receiver, Sender}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::select; +use tokio::time::{self, error::Elapsed, Instant, Sleep}; +#[cfg(feature = "websocket")] +use ws_stream_tungstenite::WsStream; + +#[cfg(unix)] +use std::path::Path; +use std::{ + collections::VecDeque, + io, + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, + vec::IntoIter, +}; + +/// Critical errors during eventloop polling +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("Mqtt state: {0}")] + MqttState(#[from] StateError), + #[error("Timeout")] + Timeout(#[from] Elapsed), + #[error("Packet parsing error: {0}")] + Mqtt5Bytes(Error), + #[cfg(feature = "use-rustls")] + #[error("Network: {0}")] + Network(#[from] tls::Error), + #[error("I/O: {0}")] + Io(#[from] io::Error), + #[error("Stream done")] + StreamDone, + #[error("Requests done")] + RequestsDone, + #[error("Cancel request by the user")] + Cancel, +} + +/// Eventloop with all the state of a connection +pub struct EventLoop { + /// Options of the current mqtt connection + pub options: MqttOptions, + /// Current state of the connection + pub state: MqttState, + outgoing_buf: Arc>, + outgoing_buf_cache: VecDeque, + /// Request stream + pub incoming_rx: Receiver<()>, + /// Requests handle to send requests + pub incoming_tx: Sender<()>, + /// Pending packets from last session + pub pending: IntoIter, + /// Network connection to the broker + pub(crate) network: Option, + /// Keep alive time + pub(crate) keepalive_timeout: Option>>, +} + +impl EventLoop { + /// New MQTT `EventLoop` + /// + /// When connection encounters critical errors (like auth failure), user has a choice to + /// access and update `options`, `state` and `requests`. + pub fn new(options: MqttOptions, cap: usize) -> EventLoop { + let (incoming_tx, incoming_rx) = bounded(1); + let pending = Vec::new(); + let pending = pending.into_iter(); + let max_inflight = options.inflight; + let manual_acks = options.manual_acks; + let state = MqttState::new(max_inflight, manual_acks, cap); + let outgoing_buf = state.outgoing_buf.clone(); + + EventLoop { + options, + state, + outgoing_buf, + outgoing_buf_cache: VecDeque::with_capacity(cap), + incoming_tx, + incoming_rx, + pending, + network: None, + keepalive_timeout: None, + } + } + + /// Returns a handle to communicate with this eventloop + #[inline] + pub fn handle(&self) -> Sender<()> { + self.incoming_tx.clone() + } + + fn clean(&mut self) { + self.network = None; + self.keepalive_timeout = None; + let pending = self.state.clean(); + self.pending = pending.into_iter(); + } + + /// Yields Next notification or outgoing request and periodically pings + /// the broker. Continuing to poll will reconnect to the broker if there is + /// a disconnection. + /// **NOTE** Don't block this while iterating + pub async fn poll(&mut self) -> Result<(), ConnectionError> { + if self.network.is_none() { + let (network, _connack) = connect(&self.options).await?; + self.network = Some(network); + + if self.keepalive_timeout.is_none() { + self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); + } + + return Ok(()); + } + + if let Err(e) = self.select().await { + self.clean(); + return Err(e); + } + + Ok(()) + } + + /// Select on network and requests and generate keepalive pings when necessary + async fn select(&mut self) -> Result<(), ConnectionError> { + let network = self.network.as_mut().unwrap(); + // let await_acks = self.state.await_acks; + let inflight_full = self.state.inflight >= self.options.inflight; + let throttle = self.options.pending_throttle; + let pending = self.pending.len() > 0; + let collision = self.state.collision.is_some(); + + // this loop is necessary as self.request_buf might be empty, in which case it is possible + // for self.state.events to be empty, and so popping off from it might return None. If None + // is returned, we select again. + loop { + select! { + // Pull a bunch of packets from network, reply in bunch and yield the first item + o = network.readb(&mut self.state) => { + o?; + // flush all the acks and return first incoming packet + network.flush(&mut self.state.write).await?; + return Ok(()); + }, + // Pull next request from user requests channel. + // If conditions in the below branch are for flow control. We read next user + // request only when inflight messages are < configured inflight and there are no + // collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this to + // work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue looks + // like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + o = self.incoming_rx.recv_async(), if !inflight_full && !pending && !collision => match o { + Ok(_request_notif) => { + // swapping to avoid blocking the mutex + std::mem::swap(&mut self.outgoing_buf_cache, &mut self.outgoing_buf.lock().unwrap().buf); + if self.outgoing_buf_cache.is_empty() { + continue; + } + for request in self.outgoing_buf_cache.drain(..) { + self.state.handle_outgoing_packet(request)?; + } + network.flush(&mut self.state.write).await?; + // remaining events in the self.state.events will be taken out in next call + // to poll() even before the select! is used. + return Ok(()) + } + Err(_) => return Err(ConnectionError::RequestsDone), + }, + // Handle the next pending packet from previous session. Disable + // this branch when done with all the pending packets + Some(request) = next_pending(throttle, &mut self.pending), if pending => { + self.state.handle_outgoing_packet(request)?; + network.flush(&mut self.state.write).await?; + return Ok(()) + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + + self.state.handle_outgoing_packet(Request::PingReq)?; + network.flush(&mut self.state.write).await?; + return Ok(()) + } + } + } + } +} + +/// This stream internally processes requests from the request stream provided to the eventloop +/// while also consuming byte stream from the network and yielding mqtt packets as the output of +/// the stream. +/// This function (for convenience) includes internal delays for users to perform internal sleeps +/// between re-connections so that cancel semantics can be used during this sleep +async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> { + // connect to the broker + let mut network = match network_connect(options).await { + Ok(network) => network, + Err(e) => { + return Err(e); + } + }; + + // make MQTT connection request (which internally awaits for ack) + let packet = match mqtt_connect(options, &mut network).await { + Ok(p) => p, + Err(e) => return Err(e), + }; + + // Last session might contain packets which aren't acked. MQTT says these packets should be + // republished in the next session + // move pending messages from state to eventloop + // let pending = self.state.clean(); + // self.pending = pending.into_iter(); + Ok((network, packet)) +} + +async fn network_connect(options: &MqttOptions) -> Result { + let network = match options.transport() { + Transport::Tcp => { + let addr = options.broker_addr.as_str(); + let port = options.port; + let socket = TcpStream::connect((addr, port)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(feature = "use-rustls")] + Transport::Tls(tls_config) => { + let socket = tls::tls_connect(options, &tls_config).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(unix)] + Transport::Unix => { + let file = options.broker_addr.as_str(); + let socket = UnixStream::connect(Path::new(file)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Ws => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let (socket, _) = connect_async(request) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Wss(tls_config) => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let connector = tls::tls_connector(&tls_config).await?; + + let (socket, _) = connect_async_with_tls_connector(request, Some(connector)) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + }; + + Ok(network) +} + +async fn mqtt_connect( + options: &MqttOptions, + network: &mut Network, +) -> Result { + let keep_alive = options.keep_alive().as_secs() as u16; + let clean_session = options.clean_session(); + let last_will = options.last_will(); + + let mut connect = Connect::new(options.client_id()); + connect.keep_alive = keep_alive; + connect.clean_session = clean_session; + connect.last_will = last_will; + + if let Some((username, password)) = options.credentials() { + let login = Login::new(username, password); + connect.login = Some(login); + } + + // mqtt connection with timeout + time::timeout(Duration::from_secs(options.connection_timeout()), async { + network.connect(connect).await?; + Ok::<_, ConnectionError>(()) + }) + .await??; + + // wait for 'timeout' time to validate connack + let packet = time::timeout(Duration::from_secs(options.connection_timeout()), async { + let packet = match network.read().await? { + Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { + Packet::ConnAck(connack) + } + Incoming::ConnAck(connack) => { + let error = format!("Broker rejected. Reason = {:?}", connack.code); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + packet => { + let error = format!("Expecting connack. Received = {:?}", packet); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + }; + + io::Result::Ok(packet) + }) + .await??; + + Ok(packet) +} + +/// Returns the next pending packet asynchronously to be used in select! +/// This is a synchronous function but made async to make it fit in select! +pub(crate) async fn next_pending( + delay: Duration, + pending: &mut IntoIter, +) -> Option { + // return next packet with a delay + time::sleep(delay).await; + pending.next() +} diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs new file mode 100644 index 000000000..684694d5a --- /dev/null +++ b/rumqttc/src/v5/framed.rs @@ -0,0 +1,120 @@ +use bytes::BytesMut; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::v5::{packet::*, Incoming, MqttState, StateError}; +use std::io; + +/// Network transforms packets <-> frames efficiently. It takes +/// advantage of pre-allocation, buffering and vectorization when +/// appropriate to achieve performance +pub struct Network { + /// Socket for IO + socket: Box, + /// Buffered reads + read: BytesMut, + /// Maximum packet size + max_incoming_size: usize, + /// Maximum readv count + max_readb_count: usize, +} + +impl Network { + pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network { + let socket = Box::new(socket) as Box; + Network { + socket, + read: BytesMut::with_capacity(10 * 1024), + max_incoming_size, + max_readb_count: 10, + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + async fn read_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + loop { + let read = self.socket.read_buf(&mut self.read).await?; + if 0 == read { + return if self.read.is_empty() { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "connection closed by peer", + )) + } else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "connection reset by peer", + )) + }; + } + + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let required = match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => return Ok(packet), + Err(Error::InsufficientBytes(required)) => required, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + // read more packets until a frame can be created. This function + // blocks until a frame can be created. Use this in a select! branch + self.read_bytes(required).await?; + } + } + + /// Read packets in bulk. This allow replies to be in bulk. This method is used + /// after the connection is established to read a bunch of incoming packets + pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { + let mut count = 0; + loop { + match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => { + state.handle_incoming_packet(packet)?; + + count += 1; + if count >= self.max_readb_count { + return Ok(()); + } + } + // If some packets are already framed, return those + Err(Error::InsufficientBytes(_)) if count > 0 => return Ok(()), + // Wait for more bytes until a frame can be created + Err(Error::InsufficientBytes(required)) => { + self.read_bytes(required).await?; + } + Err(e) => return Err(StateError::Deserialization(e)), + }; + } + } + + pub async fn connect(&mut self, connect: Connect) -> io::Result { + let mut write = BytesMut::new(); + let len = match connect.write(&mut write) { + Ok(size) => size, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + self.socket.write_all(&write[..]).await?; + Ok(len) + } + + pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { + if write.is_empty() { + return Ok(()); + } + + self.socket.write_all(&write[..]).await?; + write.clear(); + Ok(()) + } +} + +pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} +impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs new file mode 100644 index 000000000..bddc7013b --- /dev/null +++ b/rumqttc/src/v5/mod.rs @@ -0,0 +1,720 @@ +#[cfg(feature = "use-rustls")] +use std::sync::Arc; +use std::{ + collections::VecDeque, + fmt::{self, Debug, Formatter}, + time::Duration, +}; + +mod client; +mod eventloop; +mod framed; +mod notifier; +mod outgoing_buf; +#[allow(clippy::all)] +mod packet; +mod state; +#[cfg(feature = "use-rustls")] +mod tls; + +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, EventLoop}; +pub use flume::{SendError, Sender, TrySendError}; +pub use notifier::Notifier; +pub use packet::*; +pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] +pub use tls::Error; +#[cfg(feature = "use-rustls")] +pub use tokio_rustls::rustls::ClientConfig; + +pub type Incoming = Packet; + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + #[cfg(feature = "use-rustls")] + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + #[cfg(feature = "use-rustls")] + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +#[cfg(feature = "use-rustls")] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +#[cfg(feature = "use-rustls")] +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + assert!(duration.as_secs() >= 5, "Keep alives should be >= 5 secs"); + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false, + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, Notifier), ()> { + let mut eventloop = EventLoop::new(options, cap); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); + let incoming_buf = eventloop.state.incoming_buf.clone(); + let incoming_buf_cache = VecDeque::with_capacity(cap); + let request_tx = eventloop.handle(); + + let client = AsyncClient { + outgoing_buf, + request_tx, + }; + + tokio::spawn(async move { + loop { + // TODO: maybe do something like retries for some specific errors? or maybe give user + // options to configure these retries? + eventloop.poll().await.unwrap(); + } + }); + + Ok((client, Notifier::new(incoming_buf, incoming_buf_cache))) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::v5::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::v5::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs new file mode 100644 index 000000000..3ef34cb85 --- /dev/null +++ b/rumqttc/src/v5/notifier.rs @@ -0,0 +1,60 @@ +use std::{ + collections::VecDeque, + mem, + sync::{Arc, Mutex}, +}; + +use crate::v5::Incoming; + +#[derive(Debug)] +pub struct Notifier { + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, +} + +impl Notifier { + #[inline] + pub(crate) fn new( + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, + ) -> Self { + Self { + incoming_buf, + incoming_buf_cache, + } + } + + #[inline] + pub fn iter(&mut self) -> NotifierIter<'_> { + NotifierIter(self) + } +} + +impl Iterator for Notifier { + type Item = Incoming; + + #[inline] + fn next(&mut self) -> Option { + match self.incoming_buf_cache.pop_front() { + None => { + mem::swap( + &mut self.incoming_buf_cache, + &mut *self.incoming_buf.lock().unwrap(), + ); + self.incoming_buf_cache.pop_front() + } + val => val, + } + } +} + +pub struct NotifierIter<'a>(&'a mut Notifier); + +impl<'a> Iterator for NotifierIter<'a> { + type Item = Incoming; + + #[inline] + fn next(&mut self) -> Option { + self.0.next() + } +} diff --git a/rumqttc/src/v5/outgoing_buf.rs b/rumqttc/src/v5/outgoing_buf.rs new file mode 100644 index 000000000..d036dd0c8 --- /dev/null +++ b/rumqttc/src/v5/outgoing_buf.rs @@ -0,0 +1,34 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use crate::v5::Request; + +#[derive(Debug)] +pub struct OutgoingBuf { + pub(crate) buf: VecDeque, + pub(crate) pkid_counter: u16, + pub(crate) capacity: usize, +} + +impl OutgoingBuf { + #[inline] + pub fn new(cap: usize) -> Arc> { + Arc::new(Mutex::new(Self { + buf: VecDeque::with_capacity(cap), + pkid_counter: 0, + capacity: cap, + })) + } + + #[inline] + pub fn increment_pkid(&mut self) -> u16 { + self.pkid_counter = if self.pkid_counter == self.capacity as u16 { + 1 + } else { + self.pkid_counter + 1 + }; + self.pkid_counter + } +} diff --git a/rumqttc/src/v5/packet/connack.rs b/rumqttc/src/v5/packet/connack.rs new file mode 100644 index 000000000..f0e0ebaee --- /dev/null +++ b/rumqttc/src/v5/packet/connack.rs @@ -0,0 +1,553 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum ConnectReturnCode { + Success = 0, + UnspecifiedError = 128, + MalformedPacket = 129, + ProtocolError = 130, + ImplementationSpecificError = 131, + UnsupportedProtocolVersion = 132, + ClientIdentifierNotValid = 133, + BadUserNamePassword = 134, + NotAuthorized = 135, + ServerUnavailable = 136, + ServerBusy = 137, + Banned = 138, + BadAuthenticationMethod = 140, + TopicNameInvalid = 144, + PacketTooLarge = 149, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + RetainNotSupported = 154, + QoSNotSupported = 155, + UseAnotherServer = 156, + ServerMoved = 157, + ConnectionRateExceeded = 159, +} + +/// Acknowledgement to connect packet +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, +} + +impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(self.session_present as u8); + buffer.put_u8(self.code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, +} + +impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let id = read_mqtt_string(&mut bytes)?; + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_string(&mut bytes)?; + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +/// Connection return code type +fn connect_return(num: u8) -> Result { + let code = match num { + 0 => ConnectReturnCode::Success, + 128 => ConnectReturnCode::UnspecifiedError, + 129 => ConnectReturnCode::MalformedPacket, + 130 => ConnectReturnCode::ProtocolError, + 131 => ConnectReturnCode::ImplementationSpecificError, + 132 => ConnectReturnCode::UnsupportedProtocolVersion, + 133 => ConnectReturnCode::ClientIdentifierNotValid, + 134 => ConnectReturnCode::BadUserNamePassword, + 135 => ConnectReturnCode::NotAuthorized, + 136 => ConnectReturnCode::ServerUnavailable, + 137 => ConnectReturnCode::ServerBusy, + 138 => ConnectReturnCode::Banned, + 140 => ConnectReturnCode::BadAuthenticationMethod, + 144 => ConnectReturnCode::TopicNameInvalid, + 149 => ConnectReturnCode::PacketTooLarge, + 151 => ConnectReturnCode::QuotaExceeded, + 153 => ConnectReturnCode::PayloadFormatInvalid, + 154 => ConnectReturnCode::RetainNotSupported, + 155 => ConnectReturnCode::QoSNotSupported, + 156 => ConnectReturnCode::UseAnotherServer, + 157 => ConnectReturnCode::ServerMoved, + 159 => ConnectReturnCode::ConnectionRateExceeded, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample() -> ConnAck { + let properties = ConnAckProperties { + session_expiry_interval: Some(1234), + receive_max: Some(432), + max_qos: Some(2), + retain_available: Some(1), + max_packet_size: Some(100), + assigned_client_identifier: Some("test".to_owned()), + topic_alias_max: Some(456), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + wildcard_subscription_available: Some(1), + subscription_identifiers_available: Some(1), + shared_subscription_available: Some(0), + server_keep_alive: Some(1234), + response_information: Some("test".to_owned()), + server_reference: Some("test".to_owned()), + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + ConnAck { + session_present: false, + code: ConnectReturnCode::Success, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x20, // Packet type + 0x57, // Remaining length + 0x00, 0x00, // Session, code + 0x54, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x21, 0x01, 0xb0, // Receive maximum + 0x24, 0x02, // Maximum qos + 0x25, 0x01, // Retain available + 0x27, 0x00, 0x00, 0x00, 0x64, // Maximum packet size + 0x12, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Assigned client identifier + 0x22, 0x01, 0xc8, // Topic alias max + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x28, 0x01, // wildcard_subscription_available + 0x29, 0x01, // subscription_identifiers_available + 0x2a, 0x00, // shared_subscription_available + 0x13, 0x04, 0xd2, // server keep_alive + 0x1a, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // response_information + 0x1c, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication data + ] + } + + #[test] + fn connack_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connack = ConnAck::read(fixed_header, connack_bytes).unwrap(); + + assert_eq!(connack, sample()); + } + + #[test] + fn connack_encoding_works() { + let connack = sample(); + let mut buf = BytesMut::new(); + connack.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/connect.rs b/rumqttc/src/v5/packet/connect.rs new file mode 100644 index 000000000..1c630ae79 --- /dev/null +++ b/rumqttc/src/v5/packet/connect.rs @@ -0,0 +1,928 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Connection packet initiated by the client +#[derive(Debug, Clone, PartialEq)] +pub struct Connect { + /// Mqtt protocol version + pub protocol: Protocol, + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, +} + +impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn set_login, P: Into>(&mut self, u: U, p: P) -> &mut Connect { + let login = Login { + username: u.into(), + password: p.into(), + }; + + self.login = Some(login); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_string(&mut bytes)?; + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol = match protocol_level { + 4 => Protocol::V4, + 5 => Protocol::V5, + num => return Err(Error::InvalidProtocolLevel(num)), + }; + + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + // Properties in variable header + let properties = match protocol { + Protocol::V5 => ConnectProperties::read(&mut bytes)?, + Protocol::V4 => None, + }; + + let client_id = read_mqtt_string(&mut bytes)?; + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + protocol, + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + + match self.protocol { + Protocol::V4 => buffer.put_u8(0x04), + Protocol::V5 => buffer.put_u8(0x05), + } + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } +} + +/// LastWill that broker forwards on behalf of the client +#[derive(Debug, Clone, PartialEq)] +pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + pub properties: Option, +} + +impl LastWill { + pub fn new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 0; + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + }; + + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + // Properties in variable header + let properties = WillProperties::read(&mut bytes)?; + + let will_topic = read_mqtt_string(&mut bytes)?; + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + properties, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct WillProperties { + pub delay_interval: Option, + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub content_type: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, +} + +impl WillProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.delay_interval.is_some() { + len += 1 + 4; + } + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut delay_interval = None; + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut content_type = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::WillDelayInterval => { + delay_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(WillProperties { + delay_interval, + payload_format_indicator, + message_expiry_interval, + content_type, + response_topic, + correlation_data, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(delay_interval) = self.delay_interval { + buffer.put_u8(PropertyType::WillDelayInterval as u8); + buffer.put_u32(delay_interval); + } + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Login { + pub username: String, + pub password: String, +} + +impl Login { + pub fn new, P: Into>(u: U, p: P) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, +} + +impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_maximum { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(request_response_info) = self.request_response_info { + buffer.put_u8(PropertyType::RequestResponseInformation as u8); + buffer.put_u8(request_response_info); + } + + if let Some(request_problem_info) = self.request_problem_info { + buffer.put_u8(PropertyType::RequestProblemInformation as u8); + buffer.put_u8(request_problem_info); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn sample() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will_properties = WillProperties { + delay_interval: Some(1234), + payload_format_indicator: Some(0), + message_expiry_interval: Some(4321), + content_type: Some("test".to_owned()), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: Some(will_properties), + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x9d, // remaining len + 0x01, // remaining len + 0x00, 0x04, // 4 + 0x4d, // M + 0x51, // Q + 0x54, // T + 0x54, // T + 0x05, // Level + 0xc6, // connect flags + 0x00, 0x00, // keep alive + 0x2f, // properties len + 0x11, 0x00, 0x00, 0x04, 0xd2, // session expiry interval + 0x21, 0x01, 0xb0, // receive_maximum + 0x27, 0x00, 0x00, 0x00, 0x64, // max packet size + 0x22, 0x01, 0xc8, // topic_alias_max + 0x19, 0x01, // request_response_info + 0x17, 0x01, // request_problem_info + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication_method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication_data + 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, // client id + 0x2f, // will properties len + 0x18, 0x00, 0x00, 0x04, 0xd2, // will delay interval + 0x01, 0x00, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message expiry interval + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content type + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // will user properties + 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, // will topic + 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, // will payload + 0x00, 0x06, 0x6d, 0x61, 0x74, 0x74, 0x65, 0x6f, // username + 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, // password + ] + } + + #[test] + fn connect1_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample()); + } + + #[test] + fn connect1_encoding_works() { + let connect = sample(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: "hackathonmqtt5test".to_owned(), + clean_session: true, + last_will: None, + login: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x1f, 0x00, // remaining len + 0x04, // 4 + 0x4d, 0x51, 0x54, 0x54, // MQTT + 0x05, // level + 0x02, // connect flags + 0x00, 0x0a, // keep alive + 0x00, 0x00, 0x12, 0x68, 0x61, 0x63, 0x6b, 0x61, 0x74, 0x68, 0x6f, 0x6e, 0x6d, 0x71, + 0x74, 0x74, 0x35, 0x74, 0x65, 0x73, 0x74, // payload + 0x10, 0x11, 0x12, // extra bytes in the stream + ] + } + + #[test] + fn connect2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample2()); + } + + #[test] + fn connect2_encoding_works() { + let connect = sample2(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample2_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len() - 3)]); + } + + fn sample3() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: None, + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample3_bytes() -> Vec { + vec![ + 0x10, 0x6e, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0xc6, 0x00, 0x00, 0x2f, 0x11, + 0x00, 0x00, 0x04, 0xd2, 0x21, 0x01, 0xb0, 0x27, 0x00, 0x00, 0x00, 0x64, 0x22, 0x01, + 0xc8, 0x19, 0x01, 0x17, 0x01, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, + 0x74, 0x65, 0x73, 0x74, 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x16, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x00, 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, 0x00, 0x06, 0x6d, + 0x61, 0x74, 0x74, 0x65, 0x6f, 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, + ] + } + + #[test] + fn connect3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample3()); + } + + #[test] + fn connect3_encoding_works() { + let connect = sample3(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample3_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len())]); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/disconnect.rs b/rumqttc/src/v5/packet/disconnect.rs new file mode 100644 index 000000000..3662087f1 --- /dev/null +++ b/rumqttc/src/v5/packet/disconnect.rs @@ -0,0 +1,434 @@ +use std::convert::{TryFrom, TryInto}; + +use bytes::{BufMut, Bytes, BytesMut}; + +use super::*; + +use super::{property, PropertyType}; + +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum DisconnectReasonCode { + /// Close the connection normally. Do not send the Will Message. + NormalDisconnection = 0x00, + /// The Client wishes to disconnect but requires that the Server also publishes its Will Message. + DisconnectWithWillMessage = 0x04, + /// The Connection is closed but the sender either does not wish to reveal the reason, or none of the other Reason Codes apply. + UnspecifiedError = 0x80, + /// The received packet does not conform to this specification. + MalformedPacket = 0x81, + /// An unexpected or out of order packet was received. + ProtocolError = 0x82, + /// The packet received is valid but cannot be processed by this implementation. + ImplementationSpecificError = 0x83, + /// The request is not authorized. + NotAuthorized = 0x87, + /// The Server is busy and cannot continue processing requests from this Client. + ServerBusy = 0x89, + /// The Server is shutting down. + ServerShuttingDown = 0x8B, + /// The Connection is closed because no packet has been received for 1.5 times the Keepalive time. + KeepAliveTimeout = 0x8D, + /// Another Connection using the same ClientID has connected causing this Connection to be closed. + SessionTakenOver = 0x8E, + /// The Topic Filter is correctly formed, but is not accepted by this Sever. + TopicFilterInvalid = 0x8F, + /// The Topic Name is correctly formed, but is not accepted by this Client or Server. + TopicNameInvalid = 0x90, + /// The Client or Server has received more than Receive Maximum publication for which it has not sent PUBACK or PUBCOMP. + ReceiveMaximumExceeded = 0x93, + /// The Client or Server has received a PUBLISH packet containing a Topic Alias which is greater than the Maximum Topic Alias it sent in the CONNECT or CONNACK packet. + TopicAliasInvalid = 0x94, + /// The packet size is greater than Maximum Packet Size for this Client or Server. + PacketTooLarge = 0x95, + /// The received data rate is too high. + MessageRateTooHigh = 0x96, + /// An implementation or administrative imposed limit has been exceeded. + QuotaExceeded = 0x97, + /// The Connection is closed due to an administrative action. + AdministrativeAction = 0x98, + /// The payload format does not match the one specified by the Payload Format Indicator. + PayloadFormatInvalid = 0x99, + /// The Server has does not support retained messages. + RetainNotSupported = 0x9A, + /// The Client specified a QoS greater than the QoS specified in a Maximum QoS in the CONNACK. + QoSNotSupported = 0x9B, + /// The Client should temporarily change its Server. + UseAnotherServer = 0x9C, + /// The Server is moved and the Client should permanently change its server location. + ServerMoved = 0x9D, + /// The Server does not support Shared Subscriptions. + SharedSubscriptionNotSupported = 0x9E, + /// This connection is closed because the connection rate is too high. + ConnectionRateExceeded = 0x9F, + /// The maximum connection time authorized for this connection has been exceeded. + MaximumConnectTime = 0xA0, + /// The Server does not support Subscription Identifiers; the subscription is not accepted. + SubscriptionIdentifiersNotSupported = 0xA1, + /// The Server does not support Wildcard subscription; the subscription is not accepted. + WildcardSubscriptionsNotSupported = 0xA2, +} + +impl TryFrom for DisconnectReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let rc = match value { + 0x00 => Self::NormalDisconnection, + 0x04 => Self::DisconnectWithWillMessage, + 0x80 => Self::UnspecifiedError, + 0x81 => Self::MalformedPacket, + 0x82 => Self::ProtocolError, + 0x83 => Self::ImplementationSpecificError, + 0x87 => Self::NotAuthorized, + 0x89 => Self::ServerBusy, + 0x8B => Self::ServerShuttingDown, + 0x8D => Self::KeepAliveTimeout, + 0x8E => Self::SessionTakenOver, + 0x8F => Self::TopicFilterInvalid, + 0x90 => Self::TopicNameInvalid, + 0x93 => Self::ReceiveMaximumExceeded, + 0x94 => Self::TopicAliasInvalid, + 0x95 => Self::PacketTooLarge, + 0x96 => Self::MessageRateTooHigh, + 0x97 => Self::QuotaExceeded, + 0x98 => Self::AdministrativeAction, + 0x99 => Self::PayloadFormatInvalid, + 0x9A => Self::RetainNotSupported, + 0x9B => Self::QoSNotSupported, + 0x9C => Self::UseAnotherServer, + 0x9D => Self::ServerMoved, + 0x9E => Self::SharedSubscriptionNotSupported, + 0x9F => Self::ConnectionRateExceeded, + 0xA0 => Self::MaximumConnectTime, + 0xA1 => Self::SubscriptionIdentifiersNotSupported, + 0xA2 => Self::WildcardSubscriptionsNotSupported, + other => return Err(Error::InvalidConnectReturnCode(other)), + }; + + Ok(rc) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DisconnectProperties { + /// Session Expiry Interval in seconds + pub session_expiry_interval: Option, + + /// Human readable reason for the disconnect + pub reason_string: Option, + + /// List of user properties + pub user_properties: Vec<(String, String)>, + + /// String which can be used by the Client to identify another Server to use. + pub server_reference: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Disconnect { + /// Disconnect Reason Code + pub reason_code: DisconnectReasonCode, + + /// Disconnect Properties + pub properties: Option, +} + +impl DisconnectProperties { + pub fn new() -> Self { + Self { + session_expiry_interval: None, + reason_string: None, + user_properties: Vec::new(), + server_reference: None, + } + } + + fn len(&self) -> usize { + let mut length = 0; + + if self.session_expiry_interval.is_some() { + length += 1 + 4; + } + + if let Some(reason) = &self.reason_string { + length += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + length += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(server_reference) = &self.server_reference { + length += 1 + 2 + server_reference.len(); + } + + length + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let (properties_len_len, properties_len) = length(bytes.iter())?; + + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut session_expiry_interval = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut server_reference = None; + + let mut cursor = 0; + + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + let properties = Self { + session_expiry_interval, + reason_string, + user_properties, + server_reference, + }; + + Ok(Some(properties)) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let length = self.len(); + write_remaining_length(buffer, length)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + Ok(()) + } +} + +impl Disconnect { + pub fn new() -> Self { + Self { + reason_code: DisconnectReasonCode::NormalDisconnection, + properties: None, + } + } + + fn len(&self) -> usize { + if self.reason_code == DisconnectReasonCode::NormalDisconnection + && self.properties.is_none() + { + return 2; // Packet type + 0x00 + } + + let mut length = 0; + + match &self.properties { + Some(properties) => { + length += 1; // Disconnect Reason Code + + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + length += properties_len_len + properties_len; + } + None if self.reason_code == DisconnectReasonCode::NormalDisconnection => {} + None => { + length += 1; // Disconnect Reason Code + } + }; + + length + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let packet_type = fixed_header.byte1 >> 4; + let flags = fixed_header.byte1 & 0b0000_1111; + + bytes.advance(fixed_header.fixed_header_len); + + if packet_type != PacketType::Disconnect as u8 { + return Err(Error::InvalidPacketType(packet_type)); + }; + + if flags != 0x00 { + return Err(Error::MalformedPacket); + }; + + if fixed_header.remaining_len == 0 { + return Ok(Self::new()); + } + + let reason_code = read_u8(&mut bytes)?; + + let disconnect = Self { + reason_code: reason_code.try_into()?, + properties: DisconnectProperties::extract(&mut bytes)?, + }; + + Ok(disconnect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xE0); + + let length = self.len(); + + if length == 2 { + buffer.put_u8(0x00); + + return Ok(length); + } + + let len_len = write_remaining_length(buffer, length)?; + + buffer.put_u8(self.reason_code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + len_len + length) + } +} + +#[cfg(test)] +mod test { + use bytes::BytesMut; + + use super::parse_fixed_header; + + use super::{Disconnect, DisconnectProperties, DisconnectReasonCode}; + + #[test] + fn disconnect1_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + let expected = Disconnect::new(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect1_encoding_works() { + let mut buffer = BytesMut::new(); + let disconnect = Disconnect::new(); + let expected = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } + + fn sample2() -> Disconnect { + let properties = DisconnectProperties { + // TODO: change to 2137 xD + session_expiry_interval: Some(1234), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + server_reference: Some("test".to_owned()), + }; + + Disconnect { + reason_code: DisconnectReasonCode::UnspecifiedError, + properties: Some(properties), + } + } + + fn sample_bytes2() -> Vec { + vec![ + 0xE0, // Packet type + 0x22, // Remaining length + 0x80, // Disconnect Reason Code + 0x20, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x1F, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // User properties + 0x1C, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + ] + } + + #[test] + fn disconnect2_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = sample_bytes2(); + let expected = sample2(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect2_encoding_works() { + let mut buffer = BytesMut::new(); + + let disconnect = sample2(); + let expected = sample_bytes2(); + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } +} diff --git a/rumqttc/src/v5/packet/mod.rs b/rumqttc/src/v5/packet/mod.rs new file mode 100644 index 000000000..8f954c1cf --- /dev/null +++ b/rumqttc/src/v5/packet/mod.rs @@ -0,0 +1,489 @@ +use std::{ + fmt::{self, Display, Formatter}, + slice::Iter, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +mod connack; +mod connect; +mod disconnect; +mod ping; +mod puback; +mod pubcomp; +mod publish; +mod pubrec; +mod pubrel; +mod suback; +mod subscribe; +mod unsuback; +mod unsubscribe; + +pub use connack::*; +pub use connect::*; +pub use disconnect::*; +pub use ping::*; +pub use puback::*; +pub use pubcomp::*; +pub use publish::*; +pub use pubrec::*; +pub use pubrel::*; +pub use suback::*; +pub use subscribe::*; +pub use unsuback::*; +pub use unsubscribe::*; + +/// Encapsulates all MQTT packet types +#[derive(Debug, Clone, PartialEq)] +pub enum Packet { + Connect(Connect), + ConnAck(ConnAck), + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubRel(PubRel), + PubComp(PubComp), + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + PingReq, + PingResp, + Disconnect(Disconnect), +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Error { + NotConnect(PacketType), + UnexpectedConnect, + InvalidConnectReturnCode(u8), + InvalidReason(u8), + InvalidProtocol, + InvalidProtocolLevel(u8), + IncorrectPacketFormat, + InvalidPacketType(u8), + InvalidPropertyType(u8), + InvalidRetainForwardRule(u8), + InvalidQoS(u8), + InvalidSubscribeReasonCode(u8), + PacketIdZero, + SubscriptionIdZero, + PayloadSizeIncorrect, + PayloadTooLong, + PayloadSizeLimitExceeded(usize), + PayloadRequired, + TopicNotUtf8, + BoundaryCrossed(usize), + MalformedPacket, + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + InsufficientBytes(usize), +} + +/// Protocol type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Protocol { + V4, + V5, +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::read(fixed_header, packet)?), + }; + + Ok(packet) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Reads a string from bytes stream +fn read_mqtt_string(stream: &mut Bytes) -> Result { + let s = read_mqtt_bytes(stream)?; + match String::from_utf8(s.to_vec()) { + Ok(v) => Ok(v), + Err(_e) => Err(Error::TopicNotUtf8), + } +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes +fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.is_empty() { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Error = {:?}", self) + } +} diff --git a/rumqttc/src/v5/packet/ping.rs b/rumqttc/src/v5/packet/ping.rs new file mode 100644 index 000000000..1072029bd --- /dev/null +++ b/rumqttc/src/v5/packet/ping.rs @@ -0,0 +1,20 @@ +use super::*; +use bytes::{BufMut, BytesMut}; + +pub struct PingReq; + +impl PingReq { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +pub struct PingResp; + +impl PingResp { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} diff --git a/rumqttc/src/v5/packet/puback.rs b/rumqttc/src/v5/packet/puback.rs new file mode 100644 index 000000000..51131949e --- /dev/null +++ b/rumqttc/src/v5/packet/puback.rs @@ -0,0 +1,324 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, +} + +impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties, sending reason code is optional + if self.reason == PubAckReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // Reason code is optional with success if there are no properties + if self.reason == PubAckReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(v5)] +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubAck { + let properties = PubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x40, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn puback_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample()); + } + + #[test] + fn puback_encoding_works() { + let puback = sample(); + let mut buf = BytesMut::new(); + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![0x40, 0x03, 0x00, 0x2a, 0x10] + } + + #[test] + fn puback2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample2()); + } + + #[test] + fn puback2_encoding_works() { + let puback = sample2(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample2_bytes()); + } + + fn sample3() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::Success, + properties: None, + } + } + + fn sample3_bytes() -> Vec { + vec![0x40, 0x02, 0x00, 0x2a] + } + + #[test] + fn puback3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample3()); + } + + #[test] + fn puback3_encoding_works() { + let puback = sample3(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample3_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubcomp.rs b/rumqttc/src/v5/packet/pubcomp.rs new file mode 100644 index 000000000..badb97867 --- /dev/null +++ b/rumqttc/src/v5/packet/pubcomp.rs @@ -0,0 +1,237 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubCompReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubComp { + pub pkid: u16, + pub reason: PubCompReason, + pub properties: Option, +} + +impl PubComp { + pub fn new(pkid: u16) -> PubComp { + PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if fixed_header.remaining_len == 2 { + return Ok(PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubComp { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubComp { + pkid, + reason: reason(ack_reason)?, + properties: PubCompProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x70); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubCompProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubCompProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubCompProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubCompReason::Success, + 146 => PubCompReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubComp { + let properties = PubCompProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubComp { + pkid: 42, + reason: PubCompReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x70, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubcomp_parsing_works_correctly() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubcomp_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubcomp = PubComp::read(fixed_header, pubcomp_bytes).unwrap(); + assert_eq!(pubcomp, sample()); + } + + #[test] + fn pubcomp_encoding_works_correctly() { + let pubcomp = sample(); + let mut buf = BytesMut::new(); + pubcomp.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/publish.rs b/rumqttc/src/v5/packet/publish.rs new file mode 100644 index 000000000..9f0e228ea --- /dev/null +++ b/rumqttc/src/v5/packet/publish.rs @@ -0,0 +1,394 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Publish packet +#[derive(Clone, PartialEq)] +pub struct Publish { + pub dup: bool, + pub qos: QoS, + pub retain: bool, + pub topic: String, + pub pkid: u16, + pub properties: Option, + pub payload: Bytes, +} + +impl Publish { + pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload: Bytes::from(payload.into()), + } + } + + pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.topic.len(); + if self.qos != QoS::AtMostOnce && self.pkid != 0 { + len += 2; + } + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += self.payload.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let qos = qos((fixed_header.byte1 & 0b0110) >> 1)?; + let dup = (fixed_header.byte1 & 0b1000) != 0; + let retain = (fixed_header.byte1 & 0b0001) != 0; + + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let topic = read_mqtt_string(&mut bytes)?; + + // Packet identifier exists where QoS > 0 + let pkid = match qos { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?, + }; + + if qos != QoS::AtMostOnce && pkid == 0 { + return Err(Error::PacketIdZero); + } + + let publish = Publish { + dup, + retain, + qos, + pkid, + topic, + properties: PublishProperties::extract(&mut bytes)?, + payload: bytes, + }; + + Ok(publish) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + + let dup = self.dup as u8; + let qos = self.qos as u8; + let retain = self.retain as u8; + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, self.topic.as_str()); + + if self.qos != QoS::AtMostOnce { + let pkid = self.pkid; + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + buffer.extend_from_slice(&self.payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PublishProperties { + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub topic_alias: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, + pub subscription_identifiers: Vec, + pub content_type: Option, +} + +impl PublishProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.topic_alias.is_some() { + len += 1 + 2; + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + for id in self.subscription_identifiers.iter() { + len += 1 + len_len(*id); + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut topic_alias = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + let mut subscription_identifiers = Vec::new(); + let mut content_type = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAlias => { + topic_alias = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::SubscriptionIdentifier => { + let (id_len, id) = length(bytes.iter())?; + cursor += 1 + id_len; + bytes.advance(id_len); + subscription_identifiers.push(id); + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PublishProperties { + payload_format_indicator, + message_expiry_interval, + topic_alias, + response_topic, + correlation_data, + user_properties, + subscription_identifiers, + content_type, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(topic_alias) = self.topic_alias { + buffer.put_u8(PropertyType::TopicAlias as u8); + buffer.put_u16(topic_alias); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + for id in self.subscription_identifiers.iter() { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + Ok(()) + } +} + +impl fmt::Debug for Publish { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Topic = {}, Qos = {:?}, Retain = {}, Pkid = {:?}, Payload Size = {}", + self.topic, + self.qos, + self.retain, + self.pkid, + self.payload.len() + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample_v5() -> Publish { + let publish_properties = PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: Some(4321), + topic_alias: Some(100), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + subscription_identifiers: vec![120, 121], + content_type: Some("test".to_owned()), + }; + + Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: false, + topic: "test".to_string(), + pkid: 42, + properties: Some(publish_properties), + payload: Bytes::from(vec![b't', b'e', b's', b't']), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x34, // payload type + 0x3e, // remaining len + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // topic name + 0x00, 0x2a, // pkid + 0x31, // properties len + 0x01, 0x01, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message_expiry_interval + 0x23, 0x00, 0x64, // topic alias + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x0b, 0x78, // subscription_identifier + 0x0b, 0x79, // subscription_identifier + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content_type + 0x74, 0x65, 0x73, 0x74, // payload + ] + } + + #[test] + fn publish_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let publish_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let publish = Publish::read(fixed_header, publish_bytes).unwrap(); + assert_eq!(publish, sample_v5()); + } + + #[test] + fn publish_encoding_works() { + let publish = sample_v5(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/pubrec.rs b/rumqttc/src/v5/packet/pubrec.rs new file mode 100644 index 000000000..5e8de572e --- /dev/null +++ b/rumqttc/src/v5/packet/pubrec.rs @@ -0,0 +1,252 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRecReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRec { + pub pkid: u16, + pub reason: PubRecReason, + pub properties: Option, +} + +impl PubRec { + pub fn new(pkid: u16) -> PubRec { + PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRec { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRec { + pkid, + reason: reason(ack_reason)?, + properties: PubRecProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x50); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRecProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRecProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRecProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRecReason::Success, + 16 => PubRecReason::NoMatchingSubscribers, + 128 => PubRecReason::UnspecifiedError, + 131 => PubRecReason::ImplementationSpecificError, + 135 => PubRecReason::NotAuthorized, + 144 => PubRecReason::TopicNameInvalid, + 145 => PubRecReason::PacketIdentifierInUse, + 151 => PubRecReason::QuotaExceeded, + 153 => PubRecReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRec { + let properties = PubRecProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRec { + pkid: 42, + reason: PubRecReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x50, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrec_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrec_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrec = PubRec::read(fixed_header, pubrec_bytes).unwrap(); + assert_eq!(pubrec, sample()); + } + + #[test] + fn pubrec_encoding_works() { + let pubrec = sample(); + let mut buf = BytesMut::new(); + pubrec.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubrel.rs b/rumqttc/src/v5/packet/pubrel.rs new file mode 100644 index 000000000..1a1a62e4d --- /dev/null +++ b/rumqttc/src/v5/packet/pubrel.rs @@ -0,0 +1,236 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRelReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRel { + pub pkid: u16, + pub reason: PubRelReason, + pub properties: Option, +} + +impl PubRel { + pub fn new(pkid: u16) -> PubRel { + PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRel { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRel { + pkid, + reason: reason(ack_reason)?, + properties: PubRelProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x62); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRelProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRelProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRelProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRelReason::Success, + 146 => PubRelReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRel { + let properties = PubRelProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRel { + pkid: 42, + reason: PubRelReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x62, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrel_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrel_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrel = PubRel::read(fixed_header, pubrel_bytes).unwrap(); + assert_eq!(pubrel, sample()); + } + + #[test] + fn pubrel_encoding_works() { + let pubrel = sample(); + let mut buf = BytesMut::new(); + pubrel.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/suback.rs b/rumqttc/src/v5/packet/suback.rs new file mode 100644 index 000000000..0ec0ead05 --- /dev/null +++ b/rumqttc/src/v5/packet/suback.rs @@ -0,0 +1,263 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::convert::{TryFrom, TryInto}; + +/// Acknowledgement to subscribe +#[derive(Debug, Clone, PartialEq)] +pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, +} + +impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0x90); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, +} + +impl TryFrom for SubscribeReasonCode { + type Error = super::Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(super::Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> SubAck { + let properties = SubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + SubAck { + pkid: 42, + return_codes: vec![ + SubscribeReasonCode::QoS0, + SubscribeReasonCode::QoS1, + SubscribeReasonCode::QoS2, + SubscribeReasonCode::Unspecified, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x90, // packet type + 0x1b, // remaining len + 0x00, 0x2a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // user properties + 0x00, 0x01, 0x02, 0x80, // return codes + ] + } + + #[test] + fn suback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let suback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let suback = SubAck::read(fixed_header, suback_bytes).unwrap(); + assert_eq!(suback, sample()); + } + + #[test] + fn suback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/subscribe.rs b/rumqttc/src/v5/packet/subscribe.rs new file mode 100644 index 000000000..9a5c43d1e --- /dev/null +++ b/rumqttc/src/v5/packet/subscribe.rs @@ -0,0 +1,425 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Subscription packet +#[derive(Clone, PartialEq)] +pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + Subscribe { + pkid: 0, + filters: vec![filter], + properties: None, + } + } + + pub fn new_many(topics: T) -> Subscribe + where + T: IntoIterator, + { + Subscribe { + pkid: 0, + filters: topics.into_iter().collect(), + properties: None, + } + } + + pub fn empty_subscribe() -> Subscribe { + Subscribe { + pkid: 0, + filters: Vec::new(), + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_string(&mut bytes)?; + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = nolocal != 0; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = preserve_retain != 0; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Subscription filter +#[derive(Clone, PartialEq)] +pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, +} + +impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn set_nolocal(&mut self, flag: bool) -> &mut Self { + self.nolocal = flag; + self + } + + pub fn set_preserve_retain(&mut self, flag: bool) -> &mut Self { + self.preserve_retain = flag; + self + } + + pub fn set_retain_forward_rule(&mut self, rule: RetainForwardRule) -> &mut Self { + self.retain_forward_rule = rule; + self + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, +} + +impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } +} + +impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Subscribe { + let subscribe_properties = SubscribeProperties { + id: Some(100), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let mut filter = SubscribeFilter::new("hello".to_owned(), QoS::AtLeastOnce); + filter + .set_nolocal(true) + .set_preserve_retain(true) + .set_retain_forward_rule(RetainForwardRule::Never); + + Subscribe { + pkid: 42, + filters: vec![filter], + properties: Some(subscribe_properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x82, // packet type + 0x1a, // remaining length + 0x00, 0x2a, // pkid + 0x0f, // properties len + 0x0b, 0x64, // subscription identifier + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter + 0x2d, // options + ] + } + + #[test] + fn subscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Subscribe { + let filter = SubscribeFilter::new("hello/world".to_owned(), QoS::AtLeastOnce); + Subscribe { + pkid: 42, + filters: vec![filter], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x82, 0x11, 0x00, 0x2a, 0x00, 0x00, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2f, 0x77, + 0x6f, 0x72, 0x6c, 0x64, 0x01, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsuback.rs b/rumqttc/src/v5/packet/unsuback.rs new file mode 100644 index 000000000..ce01bb4cd --- /dev/null +++ b/rumqttc/src/v5/packet/unsuback.rs @@ -0,0 +1,249 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +//// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum UnsubAckReason { + Success = 0x00, + NoSubscriptionExisted = 0x11, + UnspecifiedError = 0x80, + ImplementationSpecificError = 0x83, + NotAuthorized = 0x87, + TopicFilterInvalid = 0x8F, + PacketIdentifierInUse = 0x91, +} + +/// Acknowledgement to unsubscribe +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAck { + pub pkid: u16, + pub reasons: Vec, + pub properties: Option, +} + +impl UnsubAck { + pub fn new(pkid: u16) -> UnsubAck { + UnsubAck { + pkid, + reasons: Vec::new(), + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.reasons.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = UnsubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut reasons = Vec::new(); + while bytes.has_remaining() { + let r = read_u8(&mut bytes)?; + reasons.push(reason(r)?); + } + + let unsuback = UnsubAck { + pkid, + reasons, + properties, + }; + + Ok(unsuback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xB0); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.reasons.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl UnsubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0x00 => UnsubAckReason::Success, + 0x11 => UnsubAckReason::NoSubscriptionExisted, + 0x80 => UnsubAckReason::UnspecifiedError, + 0x83 => UnsubAckReason::ImplementationSpecificError, + 0x87 => UnsubAckReason::NotAuthorized, + 0x8F => UnsubAckReason::TopicFilterInvalid, + 0x91 => UnsubAckReason::PacketIdentifierInUse, + num => return Err(Error::InvalidSubscribeReasonCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> UnsubAck { + let properties = UnsubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + UnsubAck { + pkid: 10, + reasons: vec![ + UnsubAckReason::NotAuthorized, + UnsubAckReason::TopicFilterInvalid, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xb0, // packet type + 0x19, // remaining len + 0x00, 0x0a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x87, 0x8f, // reasons + ] + } + + #[test] + fn unsuback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsuback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsuback = UnsubAck::read(fixed_header, unsuback_bytes).unwrap(); + assert_eq!(unsuback, sample()); + } + + #[test] + fn unsuback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsubscribe.rs b/rumqttc/src/v5/packet/unsubscribe.rs new file mode 100644 index 000000000..8700c5132 --- /dev/null +++ b/rumqttc/src/v5/packet/unsubscribe.rs @@ -0,0 +1,238 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Unsubscribe packet +#[derive(Debug, Clone, PartialEq)] +pub struct Unsubscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Unsubscribe { + pub fn new>(topic: S) -> Unsubscribe { + Unsubscribe { + pkid: 0, + filters: vec![topic.into()], + properties: None, + } + } + + pub fn len(&self) -> usize { + // Packet id + length of filters (unlike subscribe, this just a string. + // Hence 2 is prefixed for len per filter) + let mut len = 2 + self.filters.iter().fold(0, |s, t| 2 + s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + dbg!(pkid); + let properties = UnsubscribeProperties::extract(&mut bytes)?; + + let mut filters = Vec::with_capacity(1); + while bytes.has_remaining() { + let filter = read_mqtt_string(&mut bytes)?; + filters.push(filter); + } + + let unsubscribe = Unsubscribe { + pkid, + filters, + properties, + }; + Ok(unsubscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xA2); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + write_mqtt_string(buffer, filter); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubscribeProperties { + pub user_properties: Vec<(String, String)>, +} + +impl UnsubscribeProperties { + fn len(&self) -> usize { + let mut len = 0; + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubscribeProperties { user_properties })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Unsubscribe { + let properties = UnsubscribeProperties { + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned(), "world".to_owned()], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xa2, // packet type + 0x1e, // remaining len + 0x00, 0x0a, // pkid + 0x0d, // properties len + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter 1 + 0x00, 0x05, 0x77, 0x6f, 0x72, 0x6c, 0x64, // filter 2 + ] + } + + #[test] + fn unsubscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsubscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsubscribe = Unsubscribe::read(fixed_header, unsubscribe_bytes).unwrap(); + assert_eq!(unsubscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + println!("{:X?}", buf); + println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Unsubscribe { + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned()], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0xa2, 0x0a, 0x00, 0x0a, 0x00, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Unsubscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs new file mode 100644 index 000000000..b6a8ebfd9 --- /dev/null +++ b/rumqttc/src/v5/state.rs @@ -0,0 +1,759 @@ +use std::{ + collections::VecDeque, + io, mem, + sync::{Arc, Mutex}, + time::Instant, +}; + +use bytes::BytesMut; + +use crate::v5::{outgoing_buf::OutgoingBuf, packet::*, Incoming, Request}; + +/// Errors during state handling +#[derive(Debug, thiserror::Error)] +pub enum StateError { + /// Io Error while state is passed to network + #[error("Io error {0:?}")] + Io(#[from] io::Error), + /// Broker's error reply to client's connect packet + #[error("Connect return code `{0:?}`")] + Connect(ConnectReturnCode), + /// Invalid state for a given operation + #[error("Invalid state for a given operation")] + InvalidState, + /// Received a packet (ack) which isn't asked for + #[error("Received unsolicited ack pkid {0}")] + Unsolicited(u16), + /// Last pingreq isn't acked + #[error("Last pingreq isn't acked")] + AwaitPingResp, + /// Received a wrong packet while waiting for another packet + #[error("Received a wrong packet while waiting for another packet")] + WrongPacket, + #[error("Timeout while waiting to resolve collision")] + CollisionTimeout, + #[error("Mqtt serialization/deserialization error")] + Deserialization(Error), +} + +impl From for StateError { + fn from(e: Error) -> StateError { + StateError::Deserialization(e) + } +} + +/// State of the mqtt connection. +// Design: Methods will just modify the state of the object without doing any network operations +// Design: All inflight queues are maintained in a pre initialized vec with index as packet id. +// This is done for 2 reasons +// Bad acks or out of order acks aren't O(n) causing cpu spikes +// Any missing acks from the broker are detected during the next recycled use of packet ids +#[derive(Debug, Clone)] +pub struct MqttState { + /// Status of last ping + pub await_pingresp: bool, + /// Collision ping count. Collisions stop user requests + /// which inturn trigger pings. Multiple pings without + /// resolving collisions will result in error + pub collision_ping_count: usize, + /// Last incoming packet time + last_incoming: Instant, + /// Last outgoing packet time + last_outgoing: Instant, + /// Number of outgoing inflight publishes + pub(crate) inflight: u16, + /// Outgoing QoS 1, 2 publishes which aren't acked yet + pub(crate) outgoing_pub: Vec>, + /// Packet ids of released QoS 2 publishes + pub(crate) outgoing_rel: Vec>, + /// Packet ids on incoming QoS 2 publishes + pub(crate) incoming_pub: Vec>, + /// Last collision due to broker not acking in order + pub collision: Option, + /// Write buffer + pub write: BytesMut, + /// Indicates if acknowledgements should be send immediately + pub manual_acks: bool, + pub(crate) incoming_buf: Arc>>, + pub(crate) outgoing_buf: Arc>, +} + +impl MqttState { + /// Creates new mqtt state. Same state should be used during a + /// connection for persistent sessions while new state should + /// instantiated for clean sessions + pub fn new(max_inflight: u16, manual_acks: bool, cap: usize) -> Self { + MqttState { + await_pingresp: false, + collision_ping_count: 0, + last_incoming: Instant::now(), + last_outgoing: Instant::now(), + inflight: 0, + // index 0 is wasted as 0 is not a valid packet id + outgoing_pub: vec![None; max_inflight as usize + 1], + outgoing_rel: vec![None; max_inflight as usize + 1], + incoming_pub: vec![None; std::u16::MAX as usize + 1], + collision: None, + // TODO: Optimize these sizes later + write: BytesMut::with_capacity(10 * 1024), + manual_acks, + incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + outgoing_buf: OutgoingBuf::new(max_inflight as usize), + } + } + + /// Returns inflight outgoing packets and clears internal queues + pub fn clean(&mut self) -> Vec { + let mut pending = Vec::with_capacity(100); + // remove and collect pending publishes + for publish in self.outgoing_pub.iter_mut() { + if let Some(publish) = publish.take() { + let request = Request::Publish(publish); + pending.push(request); + } + } + + // remove and collect pending releases + for rel in self.outgoing_rel.iter_mut() { + if let Some(pkid) = rel.take() { + let request = Request::PubRel(PubRel::new(pkid)); + pending.push(request); + } + } + + // remove packed ids of incoming qos2 publishes + for id in self.incoming_pub.iter_mut() { + id.take(); + } + + self.await_pingresp = false; + self.collision_ping_count = 0; + self.inflight = 0; + pending + } + + #[inline] + pub fn inflight(&self) -> u16 { + self.inflight + } + + #[inline] + pub fn cur_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().pkid_counter + } + + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should + /// be put on to the network by the eventloop + pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { + match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect => self.outgoing_disconnect()?, + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + _ => unimplemented!(), + }; + + self.last_outgoing = Instant::now(); + Ok(()) + } + + /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the + /// user to consume and `Packet` which for the eventloop to put on the network + /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will + /// be forwarded to user and Pubck packet will be written to network + pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { + let out = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp(), + Incoming::Publish(publish) => self.handle_incoming_publish(publish), + Incoming::SubAck(_suback) => self.handle_incoming_suback(), + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), + Incoming::PubAck(puback) => self.handle_incoming_puback(puback), + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + _ => { + error!("Invalid incoming packet = {:?}", packet); + return Err(StateError::WrongPacket); + } + }; + + out?; + self.incoming_buf.lock().unwrap().push_back(packet); + self.last_incoming = Instant::now(); + Ok(()) + } + + #[inline] + fn handle_incoming_suback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + #[inline] + fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + /// Results in a publish notification in all the QoS cases. Replys with an ack + /// in case of QoS1 and Replys rec in case of QoS while also storing the message + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + let qos = publish.qos; + + match qos { + QoS::AtMostOnce => {} + QoS::AtLeastOnce => { + if !self.manual_acks { + let puback = PubAck::new(publish.pkid); + self.outgoing_puback(puback)? + } + } + QoS::ExactlyOnce => { + let pkid = publish.pkid; + self.incoming_pub[pkid as usize] = Some(pkid); + if !self.manual_acks { + let pubrec = PubRec::new(pkid); + self.outgoing_pubrec(pubrec)?; + } + } + } + + Ok(()) + } + + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + let v = match mem::replace(&mut self.outgoing_pub[puback.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited puback packet: {:?}", puback.pkid); + Err(StateError::Unsolicited(puback.pkid)) + } + }; + + if let Some(publish) = self.check_collision(puback.pkid) { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; + + publish.write(&mut self.write)?; + self.collision_ping_count = 0; + } + + v + } + + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + match mem::replace(&mut self.outgoing_pub[pubrec.pkid as usize], None) { + Some(_) => { + // NOTE: Inflight - 1 for qos2 in comp + self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); + PubRel::new(pubrec.pkid).write(&mut self.write)?; + Ok(()) + } + None => { + error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); + Err(StateError::Unsolicited(pubrec.pkid)) + } + } + } + + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + match mem::replace(&mut self.incoming_pub[pubrel.pkid as usize], None) { + Some(_) => { + PubComp::new(pubrel.pkid).write(&mut self.write)?; + Ok(()) + } + None => { + error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); + Err(StateError::Unsolicited(pubrel.pkid)) + } + } + } + + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + if let Some(publish) = self.check_collision(pubcomp.pkid) { + publish.write(&mut self.write)?; + self.collision_ping_count = 0; + } + + match mem::replace(&mut self.outgoing_rel[pubcomp.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + Err(StateError::Unsolicited(pubcomp.pkid)) + } + } + } + + #[inline] + fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + self.await_pingresp = false; + Ok(()) + } + + /// Adds next packet identifier to QoS 1 and 2 publish packets and returns + /// it buy wrapping publish in packet + fn outgoing_publish(&mut self, publish: Publish) -> Result<(), StateError> { + if publish.qos != QoS::AtMostOnce { + // client should set proper pkid + let pkid = publish.pkid; + if self + .outgoing_pub + .get(publish.pkid as usize) + .unwrap() + .is_some() + { + info!("Collision on packet id = {:?}", publish.pkid); + self.collision = Some(publish); + return Ok(()); + } + + // if there is an existing publish at this pkid, this implies that broker hasn't acked this + // packet yet. This error is possible only when broker isn't acking sequentially + self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.inflight += 1; + }; + + debug!( + "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", + publish.topic, + publish.pkid, + publish.payload.len() + ); + + publish.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + let pubrel = self.save_pubrel(pubrel)?; + + debug!("Pubrel. Pkid = {}", pubrel.pkid); + PubRel::new(pubrel.pkid).write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + puback.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + pubrec.write(&mut self.write)?; + Ok(()) + } + + /// check when the last control packet/pingreq packet is received and return + /// the status which tells if keep alive time has exceeded + /// NOTE: status will be checked for zero keepalive times also + fn outgoing_ping(&mut self) -> Result<(), StateError> { + let elapsed_in = self.last_incoming.elapsed(); + let elapsed_out = self.last_outgoing.elapsed(); + + if self.collision.is_some() { + self.collision_ping_count += 1; + if self.collision_ping_count >= 2 { + return Err(StateError::CollisionTimeout); + } + } + + // raise error if last ping didn't receive ack + if self.await_pingresp { + return Err(StateError::AwaitPingResp); + } + + self.await_pingresp = true; + + debug!( + "Pingreq, + last incoming packet before {} millisecs, + last outgoing request before {} millisecs", + elapsed_in.as_millis(), + elapsed_out.as_millis() + ); + + PingReq.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_subscribe(&mut self, subscription: Subscribe) -> Result<(), StateError> { + // client should set correct pkid + debug!( + "Subscribe. Topics = {:?}, Pkid = {:?}", + subscription.filters, subscription.pkid + ); + + subscription.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), StateError> { + debug!( + "Unsubscribe. Topics = {:?}, Pkid = {:?}", + unsub.filters, unsub.pkid + ); + + unsub.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + debug!("Disconnect"); + + Disconnect::new().write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn check_collision(&mut self, pkid: u16) -> Option { + if let Some(publish) = &self.collision { + if publish.pkid == pkid { + return self.collision.take(); + } + } + + None + } + + #[inline] + fn save_pubrel(&mut self, pubrel: PubRel) -> Result { + // pubrel's pkid should already be set correct + self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid); + Ok(pubrel) + } + + #[inline] + pub fn increment_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().increment_pkid() + } + + ///// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation + ///// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. + ///// + //fn next_pkid(&mut self) -> u16 { + // let next_pkid = self.last_pkid + 1; + + // // When next packet id is at the edge of inflight queue, + // // set await flag. This instructs eventloop to stop + // // processing requests until all the inflight publishes + // // are acked + // if next_pkid == self.max_inflight { + // self.last_pkid = 0; + // return next_pkid; + // } + + // self.last_pkid = next_pkid; + // next_pkid + //} +} + +#[cfg(test)] +mod test { + use super::{MqttState, StateError}; + use crate::v5::{packet::*, Incoming, MqttOptions, Request}; + + fn build_outgoing_publish(qos: QoS) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.qos = qos; + publish + } + + fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.pkid = pkid; + publish.qos = qos; + publish + } + + fn build_mqttstate() -> MqttState { + MqttState::new(100, false, 100) + } + + #[test] + fn next_pkid_increments_as_expected() { + let mqtt = build_mqttstate(); + + for i in 1..=100 { + let pkid = mqtt.increment_pkid(); + + // loops between 0-99. % 100 == 0 implies border + let expected = i % 100; + if expected == 0 { + break; + } + + assert_eq!(expected, pkid); + } + } + + #[test] + fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + let mut mqtt = build_mqttstate(); + + // QoS0 Publish + let mut publish = build_outgoing_publish(QoS::AtMostOnce); + publish.pkid = 1; + + // QoS 0 publish shouldn't be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 0); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 2; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 1); + + // Packet id should be incremented and publish should be saved in queue + publish.pkid = 3; + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 2); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 4; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 3); + + publish.pkid = 5; + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 4); + } + + #[test] + fn incoming_publish_should_be_added_to_queue_correctly() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + + // only qos2 publish should be add to queue + assert_eq!(pkid, 3); + } + + #[test] + fn incoming_publish_should_be_acked() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + } + + #[test] + fn incoming_publish_should_not_be_acked_with_manual_acks() { + let mut mqtt = build_mqttstate(); + mqtt.manual_acks = true; + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + assert_eq!(pkid, 3); + + assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); + } + + #[test] + fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + _ => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_puback_should_remove_correct_publish_from_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + mqtt.outgoing_publish(publish1).unwrap(); + mqtt.outgoing_publish(publish2).unwrap(); + assert_eq!(mqtt.inflight, 2); + + mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 0); + + assert!(mqtt.outgoing_pub[1].is_none()); + assert!(mqtt.outgoing_pub[2].is_none()); + } + + #[test] + fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + let _publish_out = mqtt.outgoing_publish(publish1); + let _publish_out = mqtt.outgoing_publish(publish2); + + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 2); + + // check if the remaining element's pkid is 1 + let backup = mqtt.outgoing_pub[1].clone(); + assert_eq!(backup.unwrap().pkid, 1); + + // check if the qos2 element's release pkid is 2 + assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + } + + #[test] + fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + mqtt.outgoing_publish(publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + let mut mqtt = build_mqttstate(); + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + + mqtt.outgoing_publish(publish).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 0); + } + + #[test] + fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + let mut mqtt = build_mqttstate(); + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + mqtt.outgoing_ping().unwrap(); + + // network activity other than pingresp + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 1; + mqtt.handle_outgoing_packet(Request::Publish(publish)) + .unwrap(); + mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + .unwrap(); + + // should throw error because we didn't get pingresp for previous ping + match mqtt.outgoing_ping() { + Ok(_) => panic!("Should throw pingresp await error"), + Err(StateError::AwaitPingResp) => (), + Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + } + } + + #[test] + fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + let mut mqtt = build_mqttstate(); + + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // should ping + mqtt.outgoing_ping().unwrap(); + mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // should ping + mqtt.outgoing_ping().unwrap(); + } +} diff --git a/rumqttc/src/v5/tls.rs b/rumqttc/src/v5/tls.rs new file mode 100644 index 000000000..3936b2ca8 --- /dev/null +++ b/rumqttc/src/v5/tls.rs @@ -0,0 +1,130 @@ +use tokio::net::TcpStream; +use tokio_rustls::rustls; +use tokio_rustls::rustls::client::InvalidDnsNameError; +use tokio_rustls::rustls::{ + Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName, +}; +use tokio_rustls::webpki; +use tokio_rustls::{client::TlsStream, TlsConnector}; + +use crate::v5::{Key, MqttOptions, TlsConfiguration}; + +use std::convert::TryFrom; +use std::io; +use std::io::{BufReader, Cursor}; +use std::net::AddrParseError; +use std::sync::Arc; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Addr")] + Addr(#[from] AddrParseError), + #[error("I/O")] + Io(#[from] io::Error), + #[error("Web Pki")] + WebPki(#[from] webpki::Error), + #[error("DNS name")] + DNSName(#[from] InvalidDnsNameError), + #[error("TLS error")] + TLS(#[from] rustls::Error), + #[error("No valid cert in chain")] + NoValidCertInChain, +} + +// The cert handling functions return unit right now, this is a shortcut +impl From<()> for Error { + fn from(_: ()) -> Self { + Error::NoValidCertInChain + } +} + +pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result { + let config = match tls_config { + TlsConfiguration::Simple { + ca, + alpn, + client_auth, + } => { + // Add ca to root store if the connection is TLS + let mut root_cert_store = RootCertStore::empty(); + let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))?; + + let trust_anchors = certs.iter().map_while(|cert| { + if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { + Some(OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + )) + } else { + None + } + }); + + root_cert_store.add_server_trust_anchors(trust_anchors); + + if root_cert_store.is_empty() { + return Err(Error::NoValidCertInChain); + } + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + // Add der encoded client cert and key + let mut config = if let Some(client) = client_auth.as_ref() { + let certs = + rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client.0.clone())))?; + // load appropriate Key as per the user request. The underlying signature algorithm + // of key generation determines the Signature Algorithm during the TLS Handskahe. + let read_keys = match &client.1 { + Key::RSA(k) => rustls_pemfile::rsa_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + Key::ECC(k) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + }; + let keys = match read_keys { + Ok(v) => v, + Err(_e) => return Err(Error::NoValidCertInChain), + }; + + // Get the first key. Error if it's not valid + let key = match keys.first() { + Some(k) => k.clone(), + None => return Err(Error::NoValidCertInChain), + }; + + let certs = certs.into_iter().map(Certificate).collect(); + + config.with_single_cert(certs, PrivateKey(key))? + } else { + config.with_no_client_auth() + }; + + // Set ALPN + if let Some(alpn) = alpn.as_ref() { + config.alpn_protocols.extend_from_slice(alpn); + } + + Arc::new(config) + } + TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(), + }; + + Ok(TlsConnector::from(config)) +} + +pub async fn tls_connect( + options: &MqttOptions, + tls_config: &TlsConfiguration, +) -> Result, Error> { + let addr = options.broker_addr.as_str(); + let port = options.port; + let connector = tls_connector(tls_config).await?; + let domain = ServerName::try_from(addr)?; + let tcp = TcpStream::connect((addr, port)).await?; + let tls = connector.connect(domain, tcp).await?; + Ok(tls) +}