diff --git a/.github/stale.yml b/.github/stale.yml index 37158a003..09f7b78a9 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,11 +1,8 @@ daysUntilStale: 20 - +staleLabel: "stale" +daysUntilClose: false +markComment: false exemptLabels: - "in-pipeline" - "help wanted" - "bug" - -markComment: > - This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. diff --git a/Cargo.lock b/Cargo.lock index bdefe09ff..1ef43fa23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.15.2" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7a2e47a1fbe209ee101dd6d61285226744c6c8d3c21c8dc878ba6cb9f467f3a" +checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" dependencies = [ "gimli", ] @@ -19,11 +19,11 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.6.3" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "796540673305a66d127804eef19ad696f1f204b8c1025aaca4958c17eab32877" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "getrandom 0.2.3", + "getrandom", "once_cell", "version_check", ] @@ -46,17 +46,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anyhow" -version = "1.0.42" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595d3cfa7a60d4555cb5067b99f07142a08ea778de5cf993f7b75c7d8fabc486" +checksum = "4361135be9122e0870de935d7c439aef945b9f9ddd4199a553b5270b49c82a27" [[package]] name = "argh" -version = "0.1.5" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e7317a549bc17c5278d9e72bb6e62c6aa801ac2567048e39ebc1c194249323e" +checksum = "dbb41d85d92dfab96cb95ab023c265c5e4261bb956c0fb49ca06d90c570f1958" dependencies = [ "argh_derive", "argh_shared", @@ -64,9 +73,9 @@ dependencies = [ [[package]] name = "argh_derive" -version = "0.1.5" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60949c42375351e9442e354434b0cba2ac402c1237edf673cac3a4bf983b8d3c" +checksum = "be69f70ef5497dd6ab331a50bd95c6ac6b8f7f17a7967838332743fbd58dc3b5" dependencies = [ "argh_shared", "heck", @@ -77,9 +86,9 @@ dependencies = [ [[package]] name = "argh_shared" -version = "0.1.5" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a61eb019cb8f415d162cb9f12130ee6bbe9168b7d953c17f4ad049e4051ca00" +checksum = "e6f8c380fa28aa1b36107cd97f0196474bb7241bb95a453c5c01a15ac74b2eac" [[package]] name = "arrayvec" @@ -119,13 +128,13 @@ dependencies = [ [[package]] name = "async_io_stream" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "541b3487bf601cf3a63dfba621d6d0252611f120aaf27b198f018c0e1714f0df" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" dependencies = [ - "futures 0.3.15", + "futures 0.3.21", "pharos", - "rustc_version 0.3.3", + "rustc_version", "tokio", ] @@ -142,15 +151,15 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.60" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7815ea54e4d821e791162e078acbebfd6d8c8939cd559c9335dceb1c8ca7282" +checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f" dependencies = [ "addr2line", "cc", @@ -179,11 +188,11 @@ version = "0.4.0" dependencies = [ "argh", "async-channel", - "bytes 1.0.1", - "futures 0.3.15", - "itoa", + "bytes 1.1.0", + "futures 0.3.21", + "itoa 0.4.8", "jemallocator", - "pprof 0.3.22", + "pprof 0.3.23", "pretty_env_logger", "prost 0.6.1", "rumqttc", @@ -195,9 +204,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "1.2.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "block-buffer" @@ -208,6 +217,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +dependencies = [ + "generic-array", +] + [[package]] name = "buf_redux" version = "0.8.4" @@ -220,15 +238,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.7.0" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631" +checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899" [[package]] name = "bytemuck" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966d2ab714d0f785dbac0a0396251a35280aeb42413281617d0209ab4898435" +checksum = "0e851ca7c24871e7336801608a4797d7376545b6928a10d32d75685687141ead" [[package]] name = "byteorder" @@ -244,21 +262,21 @@ checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" [[package]] name = "cache-padded" -version = "1.1.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "631ae5198c9be5e753e5cc215e1bd73c2b466a3565173db433f52bb9d3e66dba" +checksum = "c1db59621ec70f09c5e9b597b220c7a2b43611f4710dc03ceb8748637775692c" [[package]] name = "cc" -version = "1.0.69" +version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e70cc2f62c6ce1868963827bd677764c62d07c3d9a3e1fb1177ee1a9ab199eb2" +checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" [[package]] name = "cfg-if" @@ -318,9 +336,9 @@ dependencies = [ [[package]] name = "core-foundation" -version = "0.9.1" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a89e2ae426ea83155dccf10c0fa6b1463ef6d5fcb44cee0b224a408fa640a62" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" dependencies = [ "core-foundation-sys", "libc", @@ -328,33 +346,33 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" +checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" [[package]] name = "cpp_demangle" -version = "0.3.3" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ea47428dc9d2237f3c6bc134472edfd63ebba0af932e783506dcfd66f10d18a" +checksum = "eeaa953eaad386a53111e47172c2fedba671e5684c8dd601a5f474f4f118710f" dependencies = [ "cfg-if 1.0.0", ] [[package]] name = "cpufeatures" -version = "0.1.5" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" +checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" dependencies = [ "libc", ] [[package]] name = "crossbeam-channel" -version = "0.5.1" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53" dependencies = [ "cfg-if 1.0.0", "crossbeam-utils", @@ -362,19 +380,29 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.5" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" +checksum = "0bf124c720b7686e3c2663cf54062ab0f68a88af2fb6a030e87e30bf721fcb38" dependencies = [ "cfg-if 1.0.0", "lazy_static", ] +[[package]] +name = "crypto-common" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "ctor" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccc0a48a9b826acdf4028595adc9db92caea352f7af011a3034acd172a52a0aa" +checksum = "f877be4f7c9f246b183111634f75baa039715e3f46ce860677d3b19a69fb229c" dependencies = [ "quote", "syn", @@ -382,13 +410,19 @@ dependencies = [ [[package]] name = "debugid" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91cf5a8c2f2097e2a32627123508635d47ce10563d999ec1a95addf08b502ba" +checksum = "d6ee87af31d84ef885378aebca32be3d682b0e0dc119d5b4860a2c5bb5046730" dependencies = [ "uuid", ] +[[package]] +name = "diff" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" + [[package]] name = "difference" version = "2.0.0" @@ -404,6 +438,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "digest" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +dependencies = [ + "block-buffer 0.10.2", + "crypto-common", +] + [[package]] name = "directories" version = "2.0.2" @@ -416,9 +460,9 @@ dependencies = [ [[package]] name = "dirs-sys" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03d86534ed367a67548dc68113a0f5db55432fdfbb6e6f9d77704397d95d5780" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" dependencies = [ "libc", "redox_users", @@ -455,9 +499,18 @@ dependencies = [ [[package]] name = "event-listener" -version = "2.5.1" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7531096570974c3a9dcf9e4b8e1cede1ec26cf5046219fb3b9d897503b9be59" +checksum = "77f3309417938f28bf8228fcff79a4a37103981e3e186d2ccd19c74b38f4eb71" + +[[package]] +name = "fastrand" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" +dependencies = [ + "instant", +] [[package]] name = "fixedbitset" @@ -465,6 +518,19 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37ab347416e802de484e4d03c7316c48f1ecb56574dfd4a46a80f173ce1de04d" +[[package]] +name = "flume" +version = "0.10.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843c03199d0c0ca54bc1ea90ac0d507274c28abcc4f691ae8b4eaa375087c76a" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "pin-project", + "spin 0.9.2", +] + [[package]] name = "fnv" version = "1.0.7" @@ -510,9 +576,9 @@ checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678" [[package]] name = "futures" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7e43a803dae2fa37c1f6a8fe121e1f7bf9548b4dfc0522a42f34145dadfc27" +checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" dependencies = [ "futures-channel", "futures-core", @@ -525,9 +591,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2" +checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" dependencies = [ "futures-core", "futures-sink", @@ -535,15 +601,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" +checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" [[package]] name = "futures-executor" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79" +checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" dependencies = [ "futures-core", "futures-task", @@ -552,18 +618,16 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1" +checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" [[package]] name = "futures-macro" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121" +checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" dependencies = [ - "autocfg", - "proc-macro-hack", "proc-macro2", "quote", "syn", @@ -571,23 +635,22 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282" +checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" [[package]] name = "futures-task" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae" +checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" [[package]] name = "futures-util" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967" +checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ - "autocfg", "futures 0.1.31", "futures-channel", "futures-core", @@ -598,16 +661,14 @@ dependencies = [ "memchr", "pin-project-lite", "pin-utils", - "proc-macro-hack", - "proc-macro-nested", "slab", ] [[package]] name = "generic-array" -version = "0.14.4" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501466ecc8a30d1d3b7fc9229b122b2ce8ed6e9d9223f1138d4babb253e51817" +checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" dependencies = [ "typenum", "version_check", @@ -615,39 +676,30 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.1.16" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if 1.0.0", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - -[[package]] -name = "getrandom" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" +checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi 0.10.2+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] name = "gimli" -version = "0.24.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4075386626662786ddb0ec9081e7c7eeb1ba31951f447ca780ef9f5d568189" +checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" [[package]] name = "h2" -version = "0.3.3" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "825343c4eef0b63f541f8903f395dc5beb362a979b5799a84062527ef1e37726" +checksum = "62eeb471aa3e3c9197aa4bfeabfe02982f6dc96f750486c0bb0009ac58b26d2b" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -668,18 +720,18 @@ checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" [[package]] name = "headers" -version = "0.3.4" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0b7591fb62902706ae8e7aaff416b1b0fa2c0fd0878b46dc13baa3712d8a855" +checksum = "4cff78e5788be1e0ab65b04d306b2ed5092c815ec97ec70f4ebd5aee158aa55d" dependencies = [ "base64 0.13.0", "bitflags", - "bytes 1.0.1", + "bytes 1.1.0", "headers-core", "http", + "httpdate", "mime", - "sha-1", - "time", + "sha-1 0.10.0", ] [[package]] @@ -711,37 +763,37 @@ dependencies = [ [[package]] name = "http" -version = "0.2.4" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" +checksum = "31f4c6746584866f0feabcc69893c5b51beef3831656a968ed7ae254cdc4fd03" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", - "itoa", + "itoa 1.0.1", ] [[package]] name = "http-body" -version = "0.4.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60daa14be0e0786db0f03a9e57cb404c9d756eed2b6c62b9ea98ec5743ec75a9" +checksum = "1ff4f84919677303da5f147645dbea6b1881f368d03ac84e1dc09031ebd7b2c6" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.4.1" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3a87b616e37e93c22fb19bcd386f02f3af5ea98a25670ad0fce773de23c5e68" +checksum = "9100414882e15fb7feccb4897e5f0ff0ff1ca7d1a86a23208ada4d7a18e6c6c4" [[package]] name = "httpdate" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6456b8a6c8f33fee7d958fcd1b60d55b11940a79e63ae87013e6d22e26034440" +checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "humantime" @@ -754,11 +806,11 @@ dependencies = [ [[package]] name = "hyper" -version = "0.14.10" +version = "0.14.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7728a72c4c7d72665fde02204bcbd93b247721025b222ef78606f14513e0fd03" +checksum = "043f0e083e9901b6cc658a77d1eb86f4fc650bbb977a4337dd63192826aa85dd" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", @@ -767,7 +819,7 @@ dependencies = [ "http-body", "httparse", "httpdate", - "itoa", + "itoa 1.0.1", "pin-project-lite", "socket2", "tokio", @@ -789,9 +841,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" +checksum = "282a6247722caba404c065016bbfa522806e51714c34f5dfc3e4a3a46fcb4223" dependencies = [ "autocfg", "hashbrown", @@ -799,14 +851,14 @@ dependencies = [ [[package]] name = "inferno" -version = "0.10.6" +version = "0.10.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c3cbcc228d2ad2e99328c2b19f38d80ec387ca6a29f778e40e32ca9f25448c3" +checksum = "de3886428c6400486522cf44b8626e7b94ad794c14390290f2a274dcf728a58f" dependencies = [ "ahash", "atty", "indexmap", - "itoa", + "itoa 1.0.1", "lazy_static", "log", "num-format", @@ -815,20 +867,11 @@ dependencies = [ "str_stack", ] -[[package]] -name = "input_buffer" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" -dependencies = [ - "bytes 1.0.1", -] - [[package]] name = "instant" -version = "0.1.10" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee0328b1209d157ef001c94dd85b4f8f64139adb0eac2659f4b08382b2f474d" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" dependencies = [ "cfg-if 1.0.0", ] @@ -853,9 +896,15 @@ dependencies = [ [[package]] name = "itoa" -version = "0.4.7" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" + +[[package]] +name = "itoa" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" +checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" [[package]] name = "jackiechan" @@ -891,9 +940,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.51" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83bdfbace3a0e81a4253f73b49e960b053e396a11012cbd49b9b74d6a2b67062" +checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04" dependencies = [ "wasm-bindgen", ] @@ -920,15 +969,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.98" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" +checksum = "efaa7b300f3b5fe8eb6bf21ce3895e1751d9665086af2d64b42f19701015ff4f" [[package]] name = "lock_api" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" dependencies = [ "scopeguard", ] @@ -944,15 +993,15 @@ dependencies = [ [[package]] name = "matches" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" [[package]] name = "memchr" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" +checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" [[package]] name = "memmap" @@ -964,6 +1013,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "memmap2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "057a3db23999c867821a7a59feb06a578fcb03685e983dff90daf9e7d24ac08f" +dependencies = [ + "libc", +] + [[package]] name = "mime" version = "0.3.16" @@ -972,9 +1030,9 @@ checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" [[package]] name = "mime_guess" -version = "2.0.3" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2684d4c2e97d99848d30b324b00c8fcc7e5c897b7cbb5819b09e7c90e8baf212" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" dependencies = [ "mime", "unicase", @@ -992,14 +1050,15 @@ dependencies = [ [[package]] name = "mio" -version = "0.7.13" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" +checksum = "52da4364ffb0e4fe33a9841a98a3f3014fb964045ce4f7a45a398243c8d6b0c9" dependencies = [ "libc", "log", "miow", "ntapi", + "wasi 0.11.0+wasi-snapshot-preview1", "winapi", ] @@ -1020,9 +1079,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" [[package]] name = "multipart" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d050aeedc89243f5347c3e237e3e13dc76fbe4ae3742a57b94dc14f69acf76d4" +checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182" dependencies = [ "buf_redux", "httparse", @@ -1030,17 +1089,26 @@ dependencies = [ "mime", "mime_guess", "quick-error", - "rand 0.7.3", + "rand", "safemem", "tempfile", "twoway", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8d96b2e1c8da3957d58100b09f102c6d9cfdfced01b7ec5a8974044bb09dbd4" +checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" dependencies = [ "lazy_static", "libc", @@ -1086,9 +1154,9 @@ checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" [[package]] name = "ntapi" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +checksum = "c28774a7fd2fbb4f0babd8237ce554b73af68021b5f695a3cebd6c59bac0980f" dependencies = [ "winapi", ] @@ -1111,7 +1179,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bafe4179722c2894288ee77a9f044f02811c86af699344c498b0840c698a2465" dependencies = [ "arrayvec", - "itoa", + "itoa 0.4.8", ] [[package]] @@ -1135,9 +1203,9 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" dependencies = [ "hermit-abi", "libc", @@ -1145,18 +1213,18 @@ dependencies = [ [[package]] name = "object" -version = "0.25.3" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a38f2be3697a57b4060074ff41b44c16870d916ad7877c17696e063257482bc7" +checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" +checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9" [[package]] name = "opaque-debug" @@ -1166,9 +1234,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.35" +version = "0.10.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "549430950c79ae24e6d02e0b7404534ecf311d94cc9f861e9e4020187d13d885" +checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95" dependencies = [ "bitflags", "cfg-if 1.0.0", @@ -1180,15 +1248,15 @@ dependencies = [ [[package]] name = "openssl-probe" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.65" +version = "0.9.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a7907e3bfa08bb85105209cdfcb6c63d109f8f6c1ed6ca318fff5c1853fbc1d" +checksum = "7e46109c383602735fa0a2e48dd2b7c892b048e1bf69e5c3b1d804b7d9c203cb" dependencies = [ "autocfg", "cc", @@ -1208,20 +1276,30 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", "lock_api", - "parking_lot_core", + "parking_lot_core 0.8.5", +] + +[[package]] +name = "parking_lot" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" +dependencies = [ + "lock_api", + "parking_lot_core 0.9.1", ] [[package]] name = "parking_lot_core" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" +checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" dependencies = [ "cfg-if 1.0.0", "instant", @@ -1231,6 +1309,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot_core" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "pem" version = "0.8.3" @@ -1248,15 +1339,6 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" -[[package]] -name = "pest" -version = "2.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53" -dependencies = [ - "ucd-trie", -] - [[package]] name = "petgraph" version = "0.5.1" @@ -1269,28 +1351,28 @@ dependencies = [ [[package]] name = "pharos" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235c4b2ebc9552f5eba94ec982acb6c12f224980878e5b74a7d61806bb9c3591" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" dependencies = [ - "futures 0.3.15", - "rustc_version 0.4.0", + "futures 0.3.21", + "rustc_version", ] [[package]] name = "pin-project" -version = "1.0.7" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7509cc106041c40a4518d2af7a61530e1eed0e6285296a3d8c5472806ccc4a4" +checksum = "58ad3879ad3baf4e44784bc6a718a8698867bb991f8ce24d1bcbe2cfb4c3a75e" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.0.7" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c950132583b500556b1efd71d45b319029f2b71518d979fcc208e16b42426f" +checksum = "744b6f092ba29c3650faf274db506afd39944f48420f6c86b17cfe0ee1cb36bb" dependencies = [ "proc-macro2", "quote", @@ -1299,9 +1381,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" +checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" [[package]] name = "pin-utils" @@ -1311,21 +1393,21 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.19" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" +checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe" [[package]] name = "pollster" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb20dcc30536a1508e75d47dd0e399bb2fe7354dcf35cda9127f2bf1ed92e30e" +checksum = "5da3b0203fd7ee5720aa0b5e790b591aa5d3f41c3ed2c34a3a393382198af2f7" [[package]] name = "pprof" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1a6a84f377d00c5cce22c9698811140a1127dcfccab25c30f7365365a12b6e1" +checksum = "f7aa7f5e5c512dd7f44276e636257ed2c07da7800aa4a7b654320292bf571d00" dependencies = [ "backtrace", "inferno", @@ -1333,7 +1415,7 @@ dependencies = [ "libc", "log", "nix 0.19.1", - "parking_lot", + "parking_lot 0.11.2", "prost 0.7.0", "prost-build", "prost-derive 0.7.0", @@ -1344,9 +1426,9 @@ dependencies = [ [[package]] name = "pprof" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c7600124d694d855283caf9f333befe2abce090833bb638009aeddd9e156dee" +checksum = "d78fcdebc1569625891b4fefed7ece660af53082529d03d9c6e8d01b3880ab92" dependencies = [ "backtrace", "inferno", @@ -1354,7 +1436,7 @@ dependencies = [ "libc", "log", "nix 0.20.0", - "parking_lot", + "parking_lot 0.11.2", "prost 0.7.0", "prost-build", "prost-derive 0.7.0", @@ -1365,9 +1447,9 @@ dependencies = [ [[package]] name = "ppv-lite86" -version = "0.2.10" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" [[package]] name = "pretty_assertions" @@ -1375,12 +1457,24 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f81e1644e1b54f5a68959a29aa86cde704219254669da328ecfdf6a1f09d427" dependencies = [ - "ansi_term", + "ansi_term 0.11.0", "ctor", "difference", "output_vt100", ] +[[package]] +name = "pretty_assertions" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c038cb5319b9c704bf9c227c261d275bfec0ad438118a2787ce47944fb228b" +dependencies = [ + "ansi_term 0.12.1", + "ctor", + "diff", + "output_vt100", +] + [[package]] name = "pretty_env_logger" version = "0.4.0" @@ -1391,23 +1485,11 @@ dependencies = [ "log", ] -[[package]] -name = "proc-macro-hack" -version = "0.5.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" - -[[package]] -name = "proc-macro-nested" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" - [[package]] name = "proc-macro2" -version = "1.0.27" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" +checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029" dependencies = [ "unicode-xid", ] @@ -1428,7 +1510,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e6984d2f1a23009bd270b8bb56d0926810a3d483f59c987d77969e9d8e840b2" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "prost-derive 0.7.0", ] @@ -1438,7 +1520,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32d3ebd75ac2679c2af3a92246639f9fcc8a442ee420719cc4fe195b98dd5fa3" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "heck", "itertools 0.9.0", "log", @@ -1482,7 +1564,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b518d7cdd93dab1d1122cf07fa9a60771836c668dde9d9e2a139f957f0d9f1bb" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "prost 0.7.0", ] @@ -1494,55 +1576,31 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quick-xml" -version = "0.20.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26aab6b48e2590e4a64d1ed808749ba06257882b461d01ca71baeb747074a6dd" +checksum = "8533f14c8382aaad0d592c812ac3b826162128b65662331e1127b45c3d18536b" dependencies = [ "memchr", ] [[package]] name = "quote" -version = "1.0.9" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" +checksum = "b4af2ec4714533fcdf07e886f17025ace8b997b9ce51204ee69b6da831c3da57" dependencies = [ "proc-macro2", ] [[package]] name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc 0.2.0", -] - -[[package]] -name = "rand" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.3", - "rand_hc 0.3.1", -] - -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", + "rand_chacha", + "rand_core", ] [[package]] @@ -1552,16 +1610,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.3", -] - -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", + "rand_core", ] [[package]] @@ -1570,51 +1619,34 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" dependencies = [ - "getrandom 0.2.3", -] - -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - -[[package]] -name = "rand_hc" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7" -dependencies = [ - "rand_core 0.6.3", + "getrandom", ] [[package]] name = "redox_syscall" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab49abadf3f9e1c4bc499e8845e152ad87d2ad2d30371841171169e9d75feee" +checksum = "8380fe0152551244f0747b1bf41737e0f8a74f97a14ccefd1148187271634f3c" dependencies = [ "bitflags", ] [[package]] name = "redox_users" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64" +checksum = "7776223e2696f1aa4c6b0170e83212f47296a00424305117d013dfe86fb0fe55" dependencies = [ - "getrandom 0.2.3", + "getrandom", "redox_syscall", + "thiserror", ] [[package]] name = "regex" -version = "1.5.4" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" +checksum = "1a11647b6b25ff05a515cb92c365cec08801e83423a235b51e231e1808747286" dependencies = [ "aho-corasick", "memchr", @@ -1638,9 +1670,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.27" +version = "0.8.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fddb3b23626145d1776addfc307e1a1851f60ef6ca64f376bcb889697144cf0" +checksum = "e74fdc210d8f24a7dbfedc13b04ba5764f5232754ccebfdf5fff1bad791ccbc6" dependencies = [ "bytemuck", ] @@ -1654,7 +1686,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", @@ -1666,16 +1698,17 @@ version = "0.11.0" dependencies = [ "async-channel", "async-tungstenite", - "bytes 1.0.1", + "bytes 1.1.0", "color-backtrace", "crossbeam-channel", "envy", + "flume", "http", "jsonwebtoken", "log", "matches", "pollster", - "pretty_assertions", + "pretty_assertions 1.2.0", "pretty_env_logger", "rustls", "rustls-native-certs", @@ -1685,7 +1718,6 @@ dependencies = [ "tokio", "tokio-rustls", "url", - "webpki", "ws_stream_tungstenite", ] @@ -1694,14 +1726,14 @@ name = "rumqttd" version = "0.10.0" dependencies = [ "argh", - "bytes 1.0.1", + "bytes 1.1.0", "confy", "futures-util", "jackiechan", "jemallocator", "log", - "pprof 0.4.4", - "pretty_assertions", + "pprof 0.4.5", + "pretty_assertions 0.6.1", "pretty_env_logger", "rustls-pemfile 0.3.0", "segments", @@ -1715,18 +1747,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dead70b0b5e03e9c814bcb6b01e03e68f7c57a80aa48c72ec92152ab3e818d49" - -[[package]] -name = "rustc_version" -version = "0.3.3" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" -dependencies = [ - "semver 0.11.0", -] +checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" [[package]] name = "rustc_version" @@ -1734,7 +1757,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.3", + "semver", ] [[package]] @@ -1781,9 +1804,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.5" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" [[package]] name = "safemem" @@ -1825,9 +1848,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.3.1" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467" +checksum = "2dc14f172faf8a0194a3aded622712b0de276821addc574fa54fc0a1167e10dc" dependencies = [ "bitflags", "core-foundation", @@ -1838,9 +1861,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.3.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4effb91b4b8b6fb7732e670b6cee160278ff8e6bf485c7805d9e319d76e284" +checksum = "0160a13a177a45bfb43ce71c01580998474f556ad854dcbca936dd2841a5c556" dependencies = [ "core-foundation-sys", "libc", @@ -1860,42 +1883,24 @@ dependencies = [ [[package]] name = "semver" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" -dependencies = [ - "semver-parser", -] - -[[package]] -name = "semver" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f3aac57ee7f3272d8395c6e4f502f434f0e289fcd62876f70daa008c20dcabe" - -[[package]] -name = "semver-parser" -version = "0.10.2" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" -dependencies = [ - "pest", -] +checksum = "a4a3381e03edd24287172047536f20cabde766e2cd3e65e6b00fb3af51c4f38d" [[package]] name = "serde" -version = "1.0.126" +version = "1.0.136" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03" +checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.126" +version = "1.0.136" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "963a7dbc9895aeac7ac90e74f34a5d5261828f79df35cbed41e10189d3804d43" +checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9" dependencies = [ "proc-macro2", "quote", @@ -1904,40 +1909,51 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.64" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" +checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95" dependencies = [ - "itoa", + "itoa 1.0.1", "ryu", "serde", ] [[package]] name = "serde_urlencoded" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edfa57a7f8d9c1d260a549e7224100f6c43d43f9103e06dd8b4095a9b2b43ce9" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" dependencies = [ "form_urlencoded", - "itoa", + "itoa 1.0.1", "ryu", "serde", ] [[package]] name = "sha-1" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a0c8611594e2ab4ebbf06ec7cbbf0a99450b8570e96cbf5188b5d5f6ef18d81" +checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" dependencies = [ - "block-buffer", + "block-buffer 0.9.0", "cfg-if 1.0.0", "cpufeatures", - "digest", + "digest 0.9.0", "opaque-debug", ] +[[package]] +name = "sha-1" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +dependencies = [ + "cfg-if 1.0.0", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -1958,23 +1974,34 @@ dependencies = [ "num-traits", ] +[[package]] +name = "simplerouter" +version = "0.1.0" +dependencies = [ + "bytes 1.1.0", + "log", + "pretty_env_logger", + "thiserror", + "tokio", +] + [[package]] name = "slab" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f173ac3d1a7e3b28003f40de0b5ce7fe2710f9b9dc3fc38664cebee46b3b6527" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" [[package]] name = "smallvec" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" [[package]] name = "socket2" -version = "0.4.0" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e3dfc207c526015c632472a77be09cf1b6e46866581aecae5cc38fb4235dea2" +checksum = "66d72b759436ae32898a2af0a14218dbf55efde3feeb170eb623637db85ee1e0" dependencies = [ "libc", "winapi", @@ -1986,6 +2013,15 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "511254be0c5bcf062b019a6c89c01a664aa359ded62f78aa72c6fc137c0590e5" +dependencies = [ + "lock_api", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2000,21 +2036,21 @@ checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" [[package]] name = "symbolic-common" -version = "8.3.0" +version = "8.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "348885c332e7d0784d661844b13b198464144a5ebcd3bfc047a6c441867ea467" +checksum = "52ca6f4079d985e79702d1cce708bdd03ac570e220bcf87105d86f5a8ebb26be" dependencies = [ "debugid", - "memmap", + "memmap2", "stable_deref_trait", "uuid", ] [[package]] name = "symbolic-demangle" -version = "8.3.0" +version = "8.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6780c62bfbd609bffaa13d6959715850578aa43caaae7aee14f1f24ceb64f433" +checksum = "7bd7f5075d8bbe9b00eaa9cb47e788fe9405f3306f7497b617b62e5f65ec619a" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -2023,9 +2059,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.73" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7" +checksum = "ea297be220d52398dcc07ce15a209fce436d361735ac1db700cab3b6cdfb9f54" dependencies = [ "proc-macro2", "quote", @@ -2034,13 +2070,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22" +checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" dependencies = [ "cfg-if 1.0.0", + "fastrand", "libc", - "rand 0.8.4", "redox_syscall", "remove_dir_all", "winapi", @@ -2048,27 +2084,27 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4" +checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" dependencies = [ "winapi-util", ] [[package]] name = "thiserror" -version = "1.0.26" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.26" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" dependencies = [ "proc-macro2", "quote", @@ -2087,9 +2123,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.2.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b5220f05bb7de7f3f53c7c065e1199b3172696fe2db9f9c4d8ad9b4ee74c342" +checksum = "2c1c1d5a42b6245520c249549ec267180beaffcc0615401ac8e31853d4b6d8d2" dependencies = [ "tinyvec_macros", ] @@ -2102,29 +2138,29 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.8.2" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2602b8af3767c285202012822834005f596c811042315fa7e9f5b12b2a43207" +checksum = "2af73ac49756f3f7c01172e34a23e5d0216f6c32333757c2c61feb2bbff5a5ee" dependencies = [ - "autocfg", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", "mio", "num_cpus", "once_cell", - "parking_lot", + "parking_lot 0.12.0", "pin-project-lite", "signal-hook-registry", + "socket2", "tokio-macros", "winapi", ] [[package]] name = "tokio-macros" -version = "1.3.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +checksum = "b557f72f448c511a979e2564e55d74e6c4432fc96ff4f6241bc6bded342643b7" dependencies = [ "proc-macro2", "quote", @@ -2143,9 +2179,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.23.2" +version = "0.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27d5f2b839802bd8267fa19b0530f5a08b9c08cd417976be2a65d130fe1c11b" +checksum = "4151fda0cf2798550ad0b34bcfc9b9dcc2a9d2471c895c68f3a8818e54f2389e" dependencies = [ "rustls", "tokio", @@ -2154,9 +2190,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" +checksum = "50145484efff8818b5ccd256697f36863f587da82cf8b409c53adf1e840798e3" dependencies = [ "futures-core", "pin-project-lite", @@ -2165,24 +2201,24 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.13.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1a5f475f1b9d077ea1017ecbc60890fda8e54942d680ca0b1d2b47cfa2d861b" +checksum = "511de3f85caf1c98983545490c3d09685fa8eb634e57eec22bb4db271f46cbd8" dependencies = [ "futures-util", "log", "pin-project", "tokio", - "tungstenite 0.12.0", + "tungstenite 0.14.0", ] [[package]] name = "tokio-util" -version = "0.6.7" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" +checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", "futures-sink", "log", @@ -2207,9 +2243,9 @@ checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6" [[package]] name = "tracing" -version = "0.1.26" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09adeb8c97449311ccd28a427f96fb563e7fd31aabf994189879d9da2394b89d" +checksum = "4a1bdf54a7c28a2bbf701e1d2233f6c77f473486b94bee4f9678da5a148dca7f" dependencies = [ "cfg-if 1.0.0", "log", @@ -2219,9 +2255,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.18" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9ff14f98b1a4b289c6248a023c1c2fa1491062964e9fed67ab29c4e4da4a052" +checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c" dependencies = [ "lazy_static", ] @@ -2234,19 +2270,19 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" [[package]] name = "tungstenite" -version = "0.12.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24" +checksum = "a0b2d8558abd2e276b0a8df5c05a2ec762609344191e5fd23e292c910e9165b5" dependencies = [ "base64 0.13.0", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "http", "httparse", - "input_buffer", "log", - "rand 0.8.4", - "sha-1", + "rand", + "sha-1 0.9.8", + "thiserror", "url", "utf-8", ] @@ -2259,13 +2295,13 @@ checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" dependencies = [ "base64 0.13.0", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "http", "httparse", "log", - "rand 0.8.4", + "rand", "rustls", - "sha-1", + "sha-1 0.9.8", "thiserror", "url", "utf-8", @@ -2283,15 +2319,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f6906492a7cd215bfa4cf595b600146ccfac0c79bcbd1f3000162af5e8b06" - -[[package]] -name = "ucd-trie" -version = "0.1.3" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" +checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" [[package]] name = "unicase" @@ -2304,12 +2334,9 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.5" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeb8be209bb1c96b7c177c7420d26e04eccacb0eeae6b980e35fcb74678107e0" -dependencies = [ - "matches", -] +checksum = "1a01404663e3db436ed2746d9fefef640d868edae3cceb81c3b8d5732fda678f" [[package]] name = "unicode-normalization" @@ -2322,9 +2349,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b" +checksum = "7e8820f5d777f6224dc4be3632222971ac30164d4a258d595640799554ebfd99" [[package]] name = "unicode-xid" @@ -2370,9 +2397,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version_check" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "want" @@ -2386,12 +2413,13 @@ dependencies = [ [[package]] name = "warp" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332d47745e9a0c38636dbd454729b147d16bd1ed08ae67b3ab281c4506771054" +checksum = "3cef4e1e9114a4b7f1ac799f16ce71c14de5778500c5450ec6b7b920c55b587e" dependencies = [ - "bytes 1.0.1", - "futures 0.3.15", + "bytes 1.1.0", + "futures-channel", + "futures-util", "headers", "http", "hyper", @@ -2415,21 +2443,21 @@ dependencies = [ [[package]] name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" +version = "0.10.2+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" [[package]] name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" +version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.74" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54ee1d4ed486f78874278e63e4069fc1ab9f6a18ca492076ffb90c5eb2997fd" +checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06" dependencies = [ "cfg-if 1.0.0", "wasm-bindgen-macro", @@ -2437,9 +2465,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.74" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b33f6a0694ccfea53d94db8b2ed1c3a8a4c86dd936b13b9f0a15ec4a451b900" +checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca" dependencies = [ "bumpalo", "lazy_static", @@ -2452,9 +2480,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.74" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "088169ca61430fe1e58b8096c24975251700e7b1f6fd91cc9d59b04fb9b18bd4" +checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2462,9 +2490,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.74" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be2241542ff3d9f241f5e2cb6dd09b37efe786df8851c54957683a49f0987a97" +checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc" dependencies = [ "proc-macro2", "quote", @@ -2475,15 +2503,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.74" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7cff876b8f18eed75a66cf49b65e7f967cb354a7aa16003fb55dbfd25b44b4f" +checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2" [[package]] name = "web-sys" -version = "0.3.51" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e828417b379f3df7111d3a2a9e5753706cae29c41f7c4029ee9fd77f3e09e582" +checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb" dependencies = [ "js-sys", "wasm-bindgen", @@ -2501,11 +2529,12 @@ dependencies = [ [[package]] name = "which" -version = "4.1.0" +version = "4.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55551e42cbdf2ce2bedd2203d0cc08dba002c27510f86dab6d0ce304cba3dfe" +checksum = "5c4fb54e6113b6a8772ee41c3404fb0301ac79604489467e0a9ce1f3e97c24ae" dependencies = [ "either", + "lazy_static", "libc", ] @@ -2540,6 +2569,49 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6" +dependencies = [ + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5" + +[[package]] +name = "windows_i686_gnu" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615" + +[[package]] +name = "windows_i686_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316" + [[package]] name = "ws_stream_tungstenite" version = "0.7.0" @@ -2555,7 +2627,7 @@ dependencies = [ "futures-util", "log", "pharos", - "rustc_version 0.4.0", + "rustc_version", "tokio", "tungstenite 0.16.0", ] diff --git a/Cargo.toml b/Cargo.toml index bfc8970a7..33819d2ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,5 +3,6 @@ members = [ "rumqttc", "rumqttd", - "benchmarks" + "benchmarks", + "benchmarks/simplerouter", ] diff --git a/benchmarks/clients/mesh.rs b/benchmarks/clients/mesh.rs index 8f6d04eec..96f38de50 100644 --- a/benchmarks/clients/mesh.rs +++ b/benchmarks/clients/mesh.rs @@ -3,7 +3,7 @@ use tokio::task; use std::thread; use std::path::PathBuf; use bytes::Bytes; -use rumqttlog::router::{Data}; +use rumqttlog::router::Data; mod common; @@ -74,5 +74,3 @@ async fn read(tx: Sender<(usize, RouterInMessage)>) { println!("Id = {}, Total size = {}", id, total_size); } - - diff --git a/benchmarks/clients/rumqttasync.rs b/benchmarks/clients/rumqttasync.rs index 3bfb49ee4..6e4844ad8 100644 --- a/benchmarks/clients/rumqttasync.rs +++ b/benchmarks/clients/rumqttasync.rs @@ -1,4 +1,4 @@ -use rumqttc::*; +use rumqttc::v4::*; use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttasyncqos0.rs b/benchmarks/clients/rumqttasyncqos0.rs index a2a668b0b..081f19ae0 100644 --- a/benchmarks/clients/rumqttasyncqos0.rs +++ b/benchmarks/clients/rumqttasyncqos0.rs @@ -1,4 +1,4 @@ -use rumqttc::*; +use rumqttc::v4::*; use std::error::Error; use std::time::{Duration, Instant}; diff --git a/benchmarks/clients/rumqttsync.rs b/benchmarks/clients/rumqttsync.rs index da85194dd..fd94bd5df 100644 --- a/benchmarks/clients/rumqttsync.rs +++ b/benchmarks/clients/rumqttsync.rs @@ -1,4 +1,4 @@ -use rumqttc::{self, Client, Event, Incoming, MqttOptions, QoS}; +use rumqttc::v4::{Client, Event, Incoming, MqttOptions, QoS}; use std::error::Error; use std::thread; use std::time::{Duration, Instant}; diff --git a/benchmarks/simplerouter/Cargo.toml b/benchmarks/simplerouter/Cargo.toml new file mode 100644 index 000000000..3e2ee24d8 --- /dev/null +++ b/benchmarks/simplerouter/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "simplerouter" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1.1.0" +log = "0.4.14" +pretty_env_logger = "0.4.0" +thiserror = "1.0.30" +tokio = { version = "1.17.0", features = ["net", "sync", "rt-multi-thread", "io-util", "macros"] } diff --git a/benchmarks/simplerouter/src/bin/simplerouter.rs b/benchmarks/simplerouter/src/bin/simplerouter.rs new file mode 100644 index 000000000..cf5527749 --- /dev/null +++ b/benchmarks/simplerouter/src/bin/simplerouter.rs @@ -0,0 +1,11 @@ +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +#[tokio::main] +async fn main() { + pretty_env_logger::init(); + simplerouter::run(simplerouter::Config { + addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1883)), + }) + .await + .unwrap(); +} diff --git a/benchmarks/simplerouter/src/lib.rs b/benchmarks/simplerouter/src/lib.rs new file mode 100644 index 000000000..0634f6a9f --- /dev/null +++ b/benchmarks/simplerouter/src/lib.rs @@ -0,0 +1,122 @@ +use std::{io, net::SocketAddr}; + +use bytes::BytesMut; +use log::*; +use tokio::net::TcpListener; + +mod network; +mod protocol; +use network::Network; +use protocol::{v4, v5}; + +pub struct Config { + pub addr: SocketAddr, +} + +pub async fn run(config: Config) -> Result<(), Error> { + let listener = TcpListener::bind(config.addr).await?; + info!("router: listening on {}", config.addr); + + loop { + let (stream, addr) = listener.accept().await?; + info!("router: accepted connection from {}", addr); + let (network, _) = match Network::read_connect(stream).await { + Ok(v) => v, + Err(e) => { + error!("router: unable to read connect : {}", e); + continue; + } + }; + info!("connection: sent connack"); + tokio::spawn(publisher_handle(network)); + } +} + +async fn publisher_handle(mut network: Network) { + let mut payload = BytesMut::with_capacity(2); + v4::pingresp::write(&mut payload).unwrap(); + let pingresp_bytes = payload.split().freeze(); + + loop { + let packet = match network.poll().await { + Ok(packet) => packet, + Err(e) => { + error!("connection: unable to read packet: {}", e); + return; + } + }; + match packet { + protocol::Packet::V4(packet) => match packet { + v4::Packet::Disconnect => { + info!("connection: received disconnect, exiting"); + return; + } + v4::Packet::PingReq => { + if let Err(e) = network.send_data(&pingresp_bytes).await { + error!("unable to send pingresp, exiting : {}", e); + return; + }; + } + v4::Packet::Publish(publish) => { + let pkid = match publish.view_meta() { + Ok(v) => v.2, + Err(e) => { + error!("connection: malformed publish packet : {}", e); + continue; + } + }; + payload.reserve(2); + v4::puback::write(pkid, &mut payload).unwrap(); + if let Err(e) = network.send_data(&payload.split().freeze()).await { + error!("unable to send puback pkid = {}, exiting : {}", pkid, e); + return; + }; + } + p => { + error!("connection: invalid packet {:?}", p); + continue; + } + }, + protocol::Packet::V5(packet) => match packet { + v5::Packet::Disconnect => { + info!("connection: received disconnect, exiting"); + return; + } + v5::Packet::PingReq => { + if let Err(e) = network.send_data(&pingresp_bytes).await { + error!("unable to send pingresp, exiting : {}", e); + return; + }; + } + v5::Packet::Publish(publish) => { + let pkid = match publish.view_meta() { + Ok(v) => v.2, + Err(e) => { + error!("connection: malformed publish packet : {}", e); + continue; + } + }; + payload.reserve(8); + v5::puback::write(pkid, v5::puback::PubAckReason::Success, None, &mut payload) + .unwrap(); + if let Err(e) = network.send_data(&payload.split().freeze()).await { + error!("unable to send puback pkid = {}, exiting : {}", pkid, e); + return; + }; + } + p => { + error!("connection: invalid packet {:?}", p); + continue; + } + }, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("MQTT : {0}")] + MQTT(#[from] crate::protocol::Error), + #[error("i/O : {0}")] + IO(#[from] io::Error), +} diff --git a/benchmarks/simplerouter/src/network.rs b/benchmarks/simplerouter/src/network.rs new file mode 100644 index 000000000..b61c02fe4 --- /dev/null +++ b/benchmarks/simplerouter/src/network.rs @@ -0,0 +1,102 @@ +use std::io; + +use bytes::{Bytes, BytesMut}; +use log::*; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, +}; + +use crate::{ + protocol::{self, v4, v5, Connect, Packet}, + Error, +}; + +pub(crate) struct Network { + stream: TcpStream, + buf: BytesMut, + protocol_level: u8, +} + +impl Network { + pub(crate) async fn read_connect(stream: TcpStream) -> Result<(Self, Connect), Error> { + let mut network = Self { + stream, + buf: BytesMut::with_capacity(4096), + protocol_level: 0, + }; + debug!("network: reading from stream"); + network.stream.read_buf(&mut network.buf).await?; + let connect_packet = loop { + match protocol::read_first_connect(&mut network.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + network.read_atleast(count).await? + } + res => break res?, + } + }; + debug!("network: read connect"); + match &connect_packet { + Connect::V4(_) => { + network.protocol_level = 4; + let mut payload = BytesMut::with_capacity(10); + v4::connack::write(v4::connack::ConnectReturnCode::Success, false, &mut payload)?; + network.send_data(&payload.split().freeze()).await?; + } + Connect::V5(_) => { + network.protocol_level = 5; + let mut payload = BytesMut::with_capacity(10); + v5::connack::write( + v5::connack::ConnectReturnCode::Success, + false, + None, + &mut payload, + )?; + network.send_data(&payload.split().freeze()).await?; + } + } + debug!("network: sent connack"); + Ok((network, connect_packet)) + } + + async fn read_atleast(&mut self, count: usize) -> io::Result<()> { + let mut len = 0; + while len < count { + len += self.stream.read_buf(&mut self.buf).await?; + } + debug!("network: read {} bytes", len); + + Ok(()) + } + + pub(crate) async fn poll(&mut self) -> Result { + loop { + match self.protocol_level { + 4 => match v4::read_mut(&mut self.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + self.read_atleast(count).await?; + continue; + } + res => return Ok(Packet::V4(res?)), + }, + 5 => match v5::read_mut(&mut self.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + self.read_atleast(count).await?; + continue; + } + res => return Ok(Packet::V5(res?)), + }, + // SAFETY: we don't allow changing protocol_level + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + + pub(crate) async fn send_data(&mut self, data: &Bytes) -> Result<(), Error> { + debug!( + "network: sent {} bytes", + self.stream.write(data.as_ref()).await? + ); + Ok(()) + } +} diff --git a/benchmarks/simplerouter/src/protocol/mod.rs b/benchmarks/simplerouter/src/protocol/mod.rs new file mode 100644 index 000000000..2ecdbd320 --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/mod.rs @@ -0,0 +1,430 @@ +#![allow(dead_code)] +use std::{slice::Iter, str::Utf8Error}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub mod v4; +pub mod v5; + +/// Checks if the filter is valid +/// +/// https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106 +pub fn valid_filter(filter: &str) -> bool { + if filter.is_empty() { + return false; + } + + let hirerarchy = filter.split('/').collect::>(); + if let Some((last, remaining)) = hirerarchy.split_last() { + // # is not allowed in filer except as a last entry + // invalid: sport/tennis#/player + // invalid: sport/tennis/#/ranking + for entry in remaining.iter() { + if entry.contains('#') { + return false; + } + } + + // only single '#" is allowed in last entry + // invalid: sport/tennis# + if last.len() != 1 && last.contains('#') { + return false; + } + } + + true +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum Error { + #[error("Expected connect packet, received = {0:?}")] + NotConnect(PacketType), + #[error("Received an unexpected connect packet")] + UnexpectedConnect, + #[error("Invalid return code received as response for connect = {0}")] + InvalidConnectReturnCode(u8), + #[error("Invalid reason = {0}")] + InvalidReason(u8), + #[error("Invalid protocol used")] + InvalidProtocol, + #[error("Invalid protocol level")] + InvalidProtocolLevel(u8), + #[error("Invalid packet format")] + IncorrectPacketFormat, + #[error("Invalid packet type = {0}")] + InvalidPacketType(u8), + #[error("Packet type unsupported = {0:?}")] + UnsupportedPacket(PacketType), + #[error("Invalid retain forward rule = {0}")] + InvalidRetainForwardRule(u8), + #[error("Invalid QoS level = {0}")] + InvalidQoS(u8), + #[error("Invalid subscribe reason code = {0}")] + InvalidSubscribeReasonCode(u8), + #[error("Packet received has id Zero")] + PacketIdZero, + #[error("Subscription had id Zero")] + SubscriptionIdZero, + #[error("Payload size is incorrect")] + PayloadSizeIncorrect, + #[error("Payload is too long")] + PayloadTooLong, + #[error("Payload size has been exceeded by {0} bytes")] + PayloadSizeLimitExceeded(usize), + #[error("Payload is required")] + PayloadRequired, + #[error("Topic not utf-8 = {0}")] + TopicNotUtf8(#[from] Utf8Error), + #[error("Promised boundary crossed, contains {0} bytes")] + BoundaryCrossed(usize), + #[error("Packet is malformed")] + MalformedPacket, + #[error("Remaining length is malformed")] + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] + InsufficientBytes(usize), + #[error("Property does not exist = {0}")] + InvalidPropertyType(u8), +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + pub byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + pub fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + pub remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +pub fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Returns big endian u16 view from next 2 bytes +pub fn view_u16(stream: &[u8]) -> Result { + let v = match stream.get(0..2) { + Some(v) => (v[0] as u16) << 8 | (v[1] as u16), + None => return Err(Error::MalformedPacket), + }; + + Ok(v) +} + +/// Returns big endian u16 view from next 2 bytes +pub fn view_str(stream: &[u8], end: usize) -> Result<&str, Error> { + let v = match stream.get(0..end) { + Some(v) => v, + None => return Err(Error::BoundaryCrossed(stream.len())), + }; + + let v = std::str::from_utf8(v)?; + Ok(v) +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +pub fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.len() < 1 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +pub fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +pub enum Connect { + V4(v4::connect::Connect), + V5(v5::connect::Connect), +} + +#[derive(Debug)] +pub enum Packet { + V4(v4::Packet), + V5(v5::Packet), +} + +pub(crate) fn read_first_connect(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + match fixed_header.packet_type()? { + PacketType::Connect => {} + p => return Err(Error::NotConnect(p)), + } + let mut packet = packet.freeze(); + + let variable_header_index = fixed_header.fixed_header_len; + packet.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut packet)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + let protocol_level = read_u8(&mut packet)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + match protocol_level { + 4 => Ok(Connect::V4(v4::connect::connect_v4_part(packet)?)), + 5 => Ok(Connect::V5(v5::connect::connect_v5_part(packet)?)), + _ => Err(Error::InvalidProtocolLevel(protocol_level)), + } +} diff --git a/benchmarks/simplerouter/src/protocol/v4.rs b/benchmarks/simplerouter/src/protocol/v4.rs new file mode 100644 index 000000000..e418877c0 --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/v4.rs @@ -0,0 +1,863 @@ +#![allow(dead_code)] + +use super::*; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub(crate) mod connect { + use super::*; + use bytes::Bytes; + + /// Connection packet initiated by the client + #[derive(Debug, Clone, PartialEq)] + pub struct Connect { + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + } + + impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + keep_alive: 10, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut bytes)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + if protocol_level != 4 { + return Err(Error::InvalidProtocolLevel(protocol_level)); + } + + connect_v4_part(bytes) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + buffer.put_u8(0x04); + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } + } + + pub(crate) fn connect_v4_part(mut bytes: Bytes) -> Result { + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + let client_id = read_mqtt_bytes(&mut bytes)?; + let client_id = std::str::from_utf8(&client_id)?.to_owned(); + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + keep_alive, + client_id, + clean_session, + last_will, + login, + }; + + Ok(connect) + } + + /// LastWill that broker forwards on behalf of the client + #[derive(Debug, Clone, PartialEq)] + pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + } + + impl LastWill { + pub fn _new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + } + } + + fn len(&self) -> usize { + let mut len = 0; + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct Login { + username: String, + password: String, + } + + impl Login { + pub fn new>(u: S, p: S) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => { + let username = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&username)?.to_owned() + } + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => { + let password = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&password)?.to_owned() + } + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } + } +} + +pub(crate) mod connack { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum ConnectReturnCode { + Success = 0, + RefusedProtocolVersion, + BadClientId, + ServiceUnavailable, + BadUserNamePassword, + NotAuthorized, + } + + /// Acknowledgement to connect packet + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + } + + impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + }; + + Ok(connack) + } + } + + pub fn write( + code: ConnectReturnCode, + session_present: bool, + buffer: &mut BytesMut, + ) -> Result { + // sesssion present + code + let len = 1 + 1; + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(session_present as u8); + buffer.put_u8(code as u8); + + Ok(1 + count + len) + } + + /// Connection return code type + fn connect_return(num: u8) -> Result { + match num { + 0 => Ok(ConnectReturnCode::Success), + 1 => Ok(ConnectReturnCode::RefusedProtocolVersion), + 2 => Ok(ConnectReturnCode::BadClientId), + 3 => Ok(ConnectReturnCode::ServiceUnavailable), + 4 => Ok(ConnectReturnCode::BadUserNamePassword), + 5 => Ok(ConnectReturnCode::NotAuthorized), + num => Err(Error::InvalidConnectReturnCode(num)), + } + } +} + +pub(crate) mod publish { + use super::*; + use bytes::{BufMut, Bytes, BytesMut}; + + #[derive(Debug, Clone, PartialEq)] + pub struct Publish { + pub fixed_header: FixedHeader, + pub raw: Bytes, + } + + impl Publish { + // pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload: Bytes::from(payload.into()), + // } + // } + + // pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload, + // } + // } + + // pub fn len(&self) -> usize { + // let mut len = 2 + self.topic.len(); + // if self.qos != QoS::AtMostOnce && self.pkid != 0 { + // len += 2; + // } + + // len += self.payload.len(); + // len + // } + + pub fn view_meta(&self) -> Result<(&str, u8, u16, bool, bool), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + let dup = (self.fixed_header.byte1 & 0b1000) != 0; + let retain = (self.fixed_header.byte1 & 0b0001) != 0; + + // FIXME: Remove indexes and use get method + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + + let pkid = match qos { + 0 => 0, + 1 => { + let stream = &stream[topic_len..]; + let pkid = view_u16(stream)?; + pkid + } + v => return Err(Error::InvalidQoS(v)), + }; + + if qos == 1 && pkid == 0 { + return Err(Error::PacketIdZero); + } + + Ok((topic, qos, pkid, dup, retain)) + } + + pub fn view_topic(&self) -> Result<&str, Error> { + // FIXME: Remove indexes + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + Ok(topic) + } + + pub fn take_topic_and_payload(mut self) -> Result<(Bytes, Bytes), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + + let variable_header_index = self.fixed_header.fixed_header_len; + self.raw.advance(variable_header_index); + let topic = read_mqtt_bytes(&mut self.raw)?; + + match qos { + 0 => (), + 1 => self.raw.advance(2), + v => return Err(Error::InvalidQoS(v)), + }; + + let payload = self.raw; + Ok((topic, payload)) + } + + pub fn read(fixed_header: FixedHeader, bytes: Bytes) -> Result { + let publish = Publish { + fixed_header, + raw: bytes, + }; + + Ok(publish) + } + } + + pub struct PublishBytes(pub Bytes); + + impl From for Result { + fn from(raw: PublishBytes) -> Self { + let fixed_header = check(raw.0.iter(), 100 * 1024 * 1024)?; + Ok(Publish { + fixed_header, + raw: raw.0, + }) + } + } + + pub fn write( + topic: &str, + qos: QoS, + pkid: u16, + dup: bool, + retain: bool, + payload: &[u8], + buffer: &mut BytesMut, + ) -> Result { + let mut len = 2 + topic.len(); + if qos != QoS::AtMostOnce { + len += 2; + } + + len += payload.len(); + + let dup = dup as u8; + let qos = qos as u8; + let retain = retain as u8; + + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, topic); + + if qos != 0 { + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + buffer.extend_from_slice(&payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +pub(crate) mod puback { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Acknowledgement to QoS1 publish + #[derive(Debug, Clone, PartialEq)] + pub struct PubAck { + pub pkid: u16, + } + + impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { pkid } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { pkid }); + } + + // No properties len or properties if remaining len > 2 but < 4 + if fixed_header.remaining_len < 4 { + return Ok(PubAck { pkid }); + } + + let puback = PubAck { pkid }; + + Ok(puback) + } + } + + pub fn write(pkid: u16, buffer: &mut BytesMut) -> Result { + let len = 2; // pkid + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + Ok(1 + count + len) + } +} + +pub(crate) mod subscribe { + use super::*; + use bytes::{Buf, Bytes}; + + /// Subscription packet + #[derive(Debug, Clone, PartialEq)] + pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + } + + impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + }; + + let mut filters = Vec::new(); + filters.push(filter); + Subscribe { pkid: 0, filters } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { path, qos }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); // len of pkid + vec![subscribe filter len] + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_bytes(&mut bytes)?; + let path = std::str::from_utf8(&path)?.to_owned(); + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + }); + } + + let subscribe = Subscribe { pkid, filters }; + + Ok(subscribe) + } + } + + pub fn write( + filters: Vec, + pkid: u16, + buffer: &mut BytesMut, + ) -> Result { + let len = 2 + filters.iter().fold(0, |s, t| s + t.len()); // len of pkid + vec![subscribe filter len] + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len_bytes = write_remaining_length(buffer, len)?; + + // write packet id + buffer.put_u16(pkid); + + // write filters + for filter in filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + len) + } + + /// Subscription filter + #[derive(Debug, Clone, PartialEq)] + pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + } + + impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { path, qos } + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } + } +} + +pub(crate) mod suback { + use std::convert::{TryFrom, TryInto}; + + use super::*; + use bytes::{Buf, Bytes}; + + /// Acknowledgement to subscribe + #[derive(Debug, Clone, PartialEq)] + pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + } + + impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { pkid, return_codes } + } + + pub fn len(&self) -> usize { + let len = 2 + self.return_codes.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { pkid, return_codes }; + Ok(suback) + } + } + + pub fn write( + return_codes: Vec, + pkid: u16, + buffer: &mut BytesMut, + ) -> Result { + let len = 2 + return_codes.len(); + buffer.put_u8(0x90); + + let remaining_len_bytes = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + + let p: Vec = return_codes + .iter() + .map(|&code| match code { + SubscribeReasonCode::Success(qos) => qos as u8, + SubscribeReasonCode::Failure => 0x80, + }) + .collect(); + + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + len) + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SubscribeReasonCode { + Success(QoS), + Failure, + } + + impl TryFrom for SubscribeReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::Success(QoS::AtMostOnce), + 1 => SubscribeReasonCode::Success(QoS::AtLeastOnce), + 128 => SubscribeReasonCode::Failure, + v => return Err(Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } + } + + pub fn codes(c: Vec) -> Vec { + c.into_iter() + .map(|v| { + let qos = qos(v).unwrap(); + SubscribeReasonCode::Success(qos) + }) + .collect() + } +} + +pub(crate) mod pingresp { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} + +pub(crate) mod pingreq { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read_mut(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut Bytes, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet { + Connect(connect::Connect), + Publish(publish::Publish), + ConnAck(connack::ConnAck), + PubAck(puback::PubAck), + PingReq, + PingResp, + Subscribe(subscribe::Subscribe), + SubAck(suback::SubAck), + Disconnect, +} diff --git a/benchmarks/simplerouter/src/protocol/v5.rs b/benchmarks/simplerouter/src/protocol/v5.rs new file mode 100644 index 000000000..c63fe816e --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/v5.rs @@ -0,0 +1,1952 @@ +#![allow(dead_code)] + +use std::fmt; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use super::*; + +pub(crate) mod connect { + use super::*; + use bytes::Bytes; + + /// Connection packet initiated by the client + #[derive(Debug, Clone, PartialEq)] + pub struct Connect { + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, + } + + impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + keep_alive: 10, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut bytes)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol_level = read_u8(&mut bytes)?; + if protocol_level != 5 { + return Err(Error::InvalidProtocolLevel(protocol_level)); + } + + connect_v5_part(bytes) + } + } + + pub(crate) fn connect_v5_part(mut bytes: Bytes) -> Result { + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + let properties = ConnectProperties::read(&mut bytes)?; + + // Payload + let client_id = read_mqtt_bytes(&mut bytes)?; + let client_id = std::str::from_utf8(&client_id)?.to_owned(); + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + /// LastWill that broker forwards on behalf of the client + #[derive(Debug, Clone, PartialEq)] + pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + } + + impl LastWill { + pub fn _new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + } + } + + fn len(&self) -> usize { + let mut len = 0; + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + }) + } + }; + + Ok(last_will) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct Login { + username: String, + password: String, + } + + impl Login { + pub fn new>(u: S, p: S) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => { + let username = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&username)?.to_owned() + } + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => { + let password = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&password)?.to_owned() + } + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, + } + + impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_bytes(&mut bytes)?; + let method = std::str::from_utf8(&method)?.to_owned(); + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + } +} + +pub(crate) mod connack { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum ConnectReturnCode { + Success = 0x00, + UnspecifiedError = 0x80, + MalformedPacket = 0x81, + ProtocolError = 0x82, + ImplementationSpecificError = 0x83, + UnsupportedProtocolVersion = 0x84, + ClientIdentifierNotValid = 0x85, + BadUserNamePassword = 0x86, + NotAuthorized = 0x87, + ServerUnavailable = 0x88, + ServerBusy = 0x89, + Banned = 0x8a, + BadAuthenticationMethod = 0x8c, + TopicNameInvalid = 0x90, + PacketTooLarge = 0x95, + QuotaExceeded = 0x97, + PayloadFormatInvalid = 0x99, + RetainNotSupported = 0x9a, + QoSNotSupported = 0x9b, + UseAnotherServer = 0x9c, + ServerMoved = 0x9d, + ConnectionRateExceeded = 0x94, + } + + /// Acknowledgement to connect packet + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, + } + + impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // 1 byte for 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + } + + pub fn write( + code: ConnectReturnCode, + session_present: bool, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + // TODO: maybe we can remove double checking if properties == None ? + + let mut len = 1 // session present + + 1; // code + if let Some(ref properties) = properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // 1 byte for 0 len + len += 1; + } + + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + + buffer.put_u8(session_present as u8); + buffer.put_u8(code as u8); + + if let Some(properties) = properties { + properties.write(buffer)?; + } else { + // 1 byte for 0 len + buffer.put_u8(0); + } + + Ok(1 + count + len) + } + + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, + } + + impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let id = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&reason)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_bytes(&mut bytes)?; + let info = std::str::from_utf8(&info)?.to_owned(); + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reference = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let method = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } + } + + /// Connection return code type + fn connect_return(num: u8) -> Result { + match num { + 0x00 => Ok(ConnectReturnCode::Success), + 0x80 => Ok(ConnectReturnCode::UnspecifiedError), + 0x81 => Ok(ConnectReturnCode::MalformedPacket), + 0x82 => Ok(ConnectReturnCode::ProtocolError), + 0x83 => Ok(ConnectReturnCode::ImplementationSpecificError), + 0x84 => Ok(ConnectReturnCode::UnsupportedProtocolVersion), + 0x85 => Ok(ConnectReturnCode::ClientIdentifierNotValid), + 0x86 => Ok(ConnectReturnCode::BadUserNamePassword), + 0x87 => Ok(ConnectReturnCode::NotAuthorized), + 0x88 => Ok(ConnectReturnCode::ServerUnavailable), + 0x89 => Ok(ConnectReturnCode::ServerBusy), + 0x8a => Ok(ConnectReturnCode::Banned), + 0x8c => Ok(ConnectReturnCode::BadAuthenticationMethod), + 0x90 => Ok(ConnectReturnCode::TopicNameInvalid), + 0x95 => Ok(ConnectReturnCode::PacketTooLarge), + 0x97 => Ok(ConnectReturnCode::QuotaExceeded), + 0x99 => Ok(ConnectReturnCode::PayloadFormatInvalid), + 0x9a => Ok(ConnectReturnCode::RetainNotSupported), + 0x9b => Ok(ConnectReturnCode::QoSNotSupported), + 0x9c => Ok(ConnectReturnCode::UseAnotherServer), + 0x9d => Ok(ConnectReturnCode::ServerMoved), + 0x94 => Ok(ConnectReturnCode::ConnectionRateExceeded), + num => Err(Error::InvalidConnectReturnCode(num)), + } + } +} + +pub(crate) mod publish { + use super::*; + use bytes::{BufMut, Bytes, BytesMut}; + + #[derive(Debug, Clone, PartialEq)] + pub struct Publish { + pub fixed_header: FixedHeader, + pub raw: Bytes, + } + + impl Publish { + // pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload: Bytes::from(payload.into()), + // } + // } + + // pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload, + // } + // } + + // pub fn len(&self) -> usize { + // let mut len = 2 + self.topic.len(); + // if self.qos != QoS::AtMostOnce && self.pkid != 0 { + // len += 2; + // } + + // len += self.payload.len(); + // len + // } + + pub fn view_meta(&self) -> Result<(&str, u8, u16, bool, bool), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + let dup = (self.fixed_header.byte1 & 0b1000) != 0; + let retain = (self.fixed_header.byte1 & 0b0001) != 0; + + // FIXME: Remove indexes and use get method + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + + let pkid = match qos { + 0 => 0, + 1 => { + let stream = &stream[topic_len..]; + let pkid = view_u16(stream)?; + pkid + } + v => return Err(Error::InvalidQoS(v)), + }; + + if qos == 1 && pkid == 0 { + return Err(Error::PacketIdZero); + } + + Ok((topic, qos, pkid, dup, retain)) + } + + pub fn view_topic(&self) -> Result<&str, Error> { + // FIXME: Remove indexes + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + Ok(topic) + } + + pub fn take_topic_and_payload(mut self) -> Result<(Bytes, Bytes), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + + let variable_header_index = self.fixed_header.fixed_header_len; + self.raw.advance(variable_header_index); + let topic = read_mqtt_bytes(&mut self.raw)?; + + match qos { + 0 => (), + 1 => self.raw.advance(2), + v => return Err(Error::InvalidQoS(v)), + }; + + let payload = self.raw; + Ok((topic, payload)) + } + + pub fn read(fixed_header: FixedHeader, bytes: Bytes) -> Result { + let publish = Publish { + fixed_header, + raw: bytes, + }; + + Ok(publish) + } + } + + pub struct PublishBytes(pub Bytes); + + impl From for Result { + fn from(raw: PublishBytes) -> Self { + let fixed_header = check(raw.0.iter(), 100 * 1024 * 1024)?; + Ok(Publish { + fixed_header, + raw: raw.0, + }) + } + } + + pub fn write( + topic: &str, + qos: QoS, + pkid: u16, + dup: bool, + retain: bool, + payload: &[u8], + buffer: &mut BytesMut, + ) -> Result { + let mut len = 2 + topic.len(); + if qos != QoS::AtMostOnce { + len += 2; + } + + len += payload.len(); + + let dup = dup as u8; + let qos = qos as u8; + let retain = retain as u8; + + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, topic); + + if qos != 0 { + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + buffer.extend_from_slice(&payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +pub(crate) mod puback { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Acknowledgement to QoS1 publish + #[derive(Debug, Clone, PartialEq)] + pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, + } + + impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + } + + pub fn write( + pkid: u16, + reason: PubAckReason, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + buffer.put_u8(0x40); + + match &properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + let len = 2 + 1 + properties_len_len + properties_len; + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + buffer.put_u8(reason as u8); + properties.write(buffer)?; + + Ok(len + count + 1) + } + None => { + // Unlike other packets, property length can be ignored if there are + // no properties in acks + // + // TODO: maybe we should set len = 2 for PubAckReason == Success + let len = 2 + 1; + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + buffer.put_u8(reason as u8); + + Ok(len + count + 1) + } + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + } + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + } + + impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + /// Connection return code type + fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) + } +} + +pub(crate) mod subscribe { + use super::*; + use bytes::{Buf, Bytes}; + + /// Subscription packet + #[derive(Clone, PartialEq)] + pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, + } + + impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + let mut filters = Vec::new(); + filters.push(filter); + Subscribe { + pkid: 0, + filters, + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_bytes(&mut bytes)?; + let path = std::str::from_utf8(&path)?.to_owned(); + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = if nolocal == 0 { false } else { true }; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = if preserve_retain == 0 { false } else { true }; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + } + + pub fn write( + filters: Vec, + pkid: u16, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let mut len = 2 + filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + let remaining_len = len; + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(pkid); + + match &properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } + + /// Subscription filter + #[derive(Clone, PartialEq)] + pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, + } + + impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, + } + + impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, + } + + impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } + } + + impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } + } +} + +pub(crate) mod suback { + use std::convert::{TryFrom, TryInto}; + + use super::*; + use bytes::{Buf, Bytes}; + + /// Acknowledgement to subscribe + #[derive(Debug, Clone, PartialEq)] + pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, + } + + impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + } + + pub fn write( + return_codes: Vec, + pkid: u16, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + buffer.put_u8(0x90); + + let mut len = 2 + return_codes.len(); + + match &properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + let remaining_len = len; + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(pkid); + + match &properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } + + #[derive(Debug, Clone, PartialEq)] + pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + } + + impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, + } + + impl TryFrom for SubscribeReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } + } + + pub fn codes(c: Vec) -> Vec { + c.into_iter() + .map(|v| match qos(v).unwrap() { + QoS::AtMostOnce => SubscribeReasonCode::QoS0, + QoS::AtLeastOnce => SubscribeReasonCode::QoS1, + }) + .collect() + } +} + +pub(crate) mod pingresp { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read_mut(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut Bytes, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet { + Connect(connect::Connect), + Publish(publish::Publish), + ConnAck(connack::ConnAck), + PubAck(puback::PubAck), + PingReq, + PingResp, + Subscribe(subscribe::Subscribe), + SubAck(suback::SubAck), + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +} diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 54730e835..a81f49373 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -14,14 +14,15 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] +default = ["use-rustls"] websocket = ["async-tungstenite", "ws_stream_tungstenite"] +use-rustls = ["tokio-rustls", "rustls-pemfile"] [dependencies] tokio = { version = "1.0", features = ["rt", "macros", "io-util", "net", "time"] } bytes = "1.0" -webpki = "0.22.0" -tokio-rustls = "0.23.2" -rustls-pemfile = "0.3.0" +tokio-rustls = { version = "0.23.2", optional = true } +rustls-pemfile = { version = "0.3.0", optional = true } async-tungstenite = { version = "0.16.1", default-features = false, features = ["tokio-rustls-native-certs"], optional = true } ws_stream_tungstenite = { version = "0.7.0", default-features = false, features = ["tokio_io"], optional = true } pollster = "0.2" @@ -30,6 +31,7 @@ log = "0.4" thiserror = "1.0.21" http = "^0.2" url = { version = "2.2", default-features = false, optional = true } +flume = "0.10.10" [dev-dependencies] pretty_env_logger = "0.4" @@ -42,4 +44,4 @@ tokio = { version = "1.0", features = ["full", "macros"] } matches = "0.1.8" rustls = "0.20.2" rustls-native-certs = "0.6.1" -pretty_assertions = "0.6.1" +pretty_assertions = "1.1.0" diff --git a/rumqttc/examples/async_manual_acks_v5.rs b/rumqttc/examples/async_manual_acks_v5.rs new file mode 100644 index 000000000..fde55ad81 --- /dev/null +++ b/rumqttc/examples/async_manual_acks_v5.rs @@ -0,0 +1,81 @@ +#![allow(dead_code, unused_imports)] +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +fn create_conn() -> (AsyncClient, EventLoop) { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_manual_acks(true) + .set_clean_session(false); + + AsyncClient::new(mqttoptions, 10) +} + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + todo!("fix this example with new way of spawning clients") + // pretty_env_logger::init(); + + // // create mqtt connection with clean_session = false and manual_acks = true + // let (client, mut eventloop) = create_conn(); + + // // subscribe example topic + // client + // .subscribe("hello/world", QoS::AtLeastOnce) + // .await + // .unwrap(); + + // task::spawn(async move { + // // send some messages to example topic and disconnect + // requests(client.clone()).await; + // client.disconnect().await.unwrap() + // }); + + // loop { + // // get subscribed messages without acking + // let event = eventloop.poll().await; + // println!("{:?}", event); + // if let Err(_err) = event { + // // break loop on disconnection + // break; + // } + // } + + // // create new broker connection + // let (_client, mut eventloop) = create_conn(); + + // loop { + // // previously published messages should be republished after reconnection. + // let event = eventloop.poll().await; + // println!("{:?}", event); + + // todo!("fix the commented out code below") + + // // match event { + // // Ok(Event::Incoming(Incoming::Publish(publish))) => { + // // // this time we will ack incoming publishes. + // // // Its important not to block eventloop as this can cause deadlock. + // // let c = client.clone(); + // // tokio::spawn(async move { + // // c.ack(&publish).await.unwrap(); + // // }); + // // } + // // _ => {} + // // } + // } +} + +async fn requests(mut client: AsyncClient) { + for i in 1..=10 { + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } +} diff --git a/rumqttc/examples/asyncpubsub_v5.rs b/rumqttc/examples/asyncpubsub_v5.rs new file mode 100644 index 000000000..b398a5e3f --- /dev/null +++ b/rumqttc/examples/asyncpubsub_v5.rs @@ -0,0 +1,43 @@ +use tokio::{task, time}; + +use rumqttc::v5::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(worker_threads = 1)] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + requests(client).await; + time::sleep(Duration::from_secs(3)).await; + }); + + loop { + let event = eventloop.poll().await; + println!("{:?}", event.unwrap()); + } +} + +async fn requests(mut client: AsyncClient) { + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap(); + + for i in 1..=10 { + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; i]) + .await + .unwrap(); + + time::sleep(Duration::from_secs(1)).await; + } + + time::sleep(Duration::from_secs(120)).await; +} diff --git a/rumqttc/examples/syncpubsub_v5.rs b/rumqttc/examples/syncpubsub_v5.rs new file mode 100644 index 000000000..857ab4789 --- /dev/null +++ b/rumqttc/examples/syncpubsub_v5.rs @@ -0,0 +1,35 @@ +use rumqttc::v5::{Client, LastWill, MqttOptions, QoS}; +use std::thread; +use std::time::Duration; + +fn main() { + pretty_env_logger::init(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + let will = LastWill::new("hello/world", "good bye", QoS::AtMostOnce, false); + mqttoptions + .set_keep_alive(Duration::from_secs(5)) + .set_last_will(will); + + let (client, mut connection) = Client::new(mqttoptions, 10); + thread::spawn(move || publish(client)); + + for (i, notification) in connection.iter().enumerate() { + println!("{}. Notification = {:?}", i, notification); + } + + println!("Done with the stream!!"); +} + +fn publish(mut client: Client) { + client.subscribe("hello/+/world", QoS::AtMostOnce).unwrap(); + for i in 0..10 { + let payload = vec![1; i as usize]; + let topic = format!("hello/{}/world", i); + let qos = QoS::AtLeastOnce; + + client.publish(topic, qos, true, payload).unwrap(); + } + + thread::sleep(Duration::from_secs(1)); +} diff --git a/rumqttc/examples/tls.rs b/rumqttc/examples/tls.rs index 2bb1b2272..0e25860e3 100644 --- a/rumqttc/examples/tls.rs +++ b/rumqttc/examples/tls.rs @@ -1,11 +1,12 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. - -use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; -use rustls::ClientConfig; use std::error::Error; +#[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { + use rumqttc::{self, AsyncClient, Event, Incoming, MqttOptions, Transport}; + use rustls::ClientConfig; + pretty_env_logger::init(); color_backtrace::install(); @@ -43,3 +44,8 @@ async fn main() -> Result<(), Box> { } } } + +#[cfg(not(feature = "use-rustls"))] +fn main() -> Result<(), Box> { + panic!("Enable feature 'use-rustls'"); +} diff --git a/rumqttc/examples/tls2.rs b/rumqttc/examples/tls2.rs index c6df58e85..496a806a0 100644 --- a/rumqttc/examples/tls2.rs +++ b/rumqttc/examples/tls2.rs @@ -1,10 +1,11 @@ //! Example of how to configure rumqttd to connect to a server using TLS and authentication. - -use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; use std::error::Error; +#[cfg(feature = "use-rustls")] #[tokio::main] async fn main() -> Result<(), Box> { + use rumqttc::{self, AsyncClient, Key, MqttOptions, TlsConfiguration, Transport}; + pretty_env_logger::init(); color_backtrace::install(); @@ -43,3 +44,8 @@ async fn main() -> Result<(), Box> { Ok(()) } + +#[cfg(not(feature = "use-rustls"))] +fn main() -> Result<(), Box> { + panic!("Enable feature 'use-rustls'"); +} diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index fb4d8cad5..e65619b3f 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,11 +1,14 @@ -use crate::{framed::Network, Transport}; -use crate::{tls, Incoming, MqttState, Packet, Request, StateError}; -use crate::{MqttOptions, Outgoing}; +use crate::framed::Network; +#[cfg(feature = "use-rustls")] +use crate::tls; +use crate::{Incoming, MqttOptions, MqttState, Outgoing, Packet, Request, StateError, Transport}; use crate::mqttbytes::v4::*; use async_channel::{bounded, Receiver, Sender}; #[cfg(feature = "websocket")] -use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use async_tungstenite::tokio::connect_async; +#[cfg(all(feature = "use-rustls", feature = "websocket"))] +use async_tungstenite::tokio::connect_async_with_tls_connector; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; @@ -31,6 +34,10 @@ pub enum ConnectionError { #[cfg(feature = "websocket")] #[error("Websocket: {0}")] Websocket(#[from] async_tungstenite::tungstenite::error::Error), + #[cfg(feature = "websocket")] + #[error("Websocket Connect: {0}")] + WsConnect(#[from] http::Error), + #[cfg(feature = "use-rustls")] #[error("TLS: {0}")] Tls(#[from] tls::Error), #[error("I/O: {0}")] @@ -276,6 +283,7 @@ async fn network_connect(options: &MqttOptions) -> Result { let socket = tls::tls_connect(options, &tls_config).await?; Network::new(socket, options.max_incoming_packet_size) @@ -292,21 +300,19 @@ async fn network_connect(options: &MqttOptions) -> Result { let request = http::Request::builder() .method(http::Method::GET) .uri(options.broker_addr.as_str()) .header("Sec-WebSocket-Protocol", "mqttv3.1") - .body(()) - .unwrap(); + .body(())?; let connector = tls::tls_connector(&tls_config).await?; diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 364f651fd..a124168d3 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -100,6 +100,7 @@ extern crate log; use std::fmt::{self, Debug, Formatter}; +#[cfg(feature = "use-rustls")] use std::sync::Arc; use std::time::Duration; @@ -108,7 +109,9 @@ mod eventloop; mod framed; pub mod mqttbytes; mod state; +#[cfg(feature = "use-rustls")] mod tls; +pub mod v5; pub use async_channel::{SendError, Sender, TrySendError}; pub use client::{AsyncClient, Client, ClientError, Connection}; @@ -116,7 +119,9 @@ pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; pub use state::{MqttState, StateError}; -pub use tls::Error; +#[cfg(feature = "use-rustls")] +pub use tls::Error as TlsError; +#[cfg(feature = "use-rustls")] pub use tokio_rustls::rustls::ClientConfig; pub type Incoming = Packet; @@ -194,14 +199,15 @@ impl From for Request { #[derive(Clone)] pub enum Transport { Tcp, + #[cfg(feature = "use-rustls")] Tls(TlsConfiguration), #[cfg(unix)] Unix, #[cfg(feature = "websocket")] #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] Ws, - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] Wss(TlsConfiguration), } @@ -218,6 +224,7 @@ impl Transport { } /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] pub fn tls( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -232,6 +239,7 @@ impl Transport { Self::tls_with_config(config) } + #[cfg(feature = "use-rustls")] pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { Self::Tls(tls_config) } @@ -249,8 +257,8 @@ impl Transport { } /// Use secure websockets with tls as transport - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss( ca: Vec, client_auth: Option<(Vec, Key)>, @@ -265,14 +273,15 @@ impl Transport { Self::wss_with_config(config) } - #[cfg(feature = "websocket")] - #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { Self::Wss(tls_config) } } #[derive(Clone)] +#[cfg(feature = "use-rustls")] pub enum TlsConfiguration { Simple { /// connection method @@ -286,6 +295,7 @@ pub enum TlsConfiguration { Rustls(Arc), } +#[cfg(feature = "use-rustls")] impl From for TlsConfiguration { fn from(config: ClientConfig) -> Self { TlsConfiguration::Rustls(Arc::new(config)) @@ -715,7 +725,7 @@ mod test { } #[test] - #[cfg(feature = "websocket")] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] fn no_scheme() { let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); diff --git a/rumqttc/src/v5/client/asyncclient.rs b/rumqttc/src/v5/client/asyncclient.rs new file mode 100644 index 000000000..fb39f8d6a --- /dev/null +++ b/rumqttc/src/v5/client/asyncclient.rs @@ -0,0 +1,312 @@ +use std::sync::{Arc, Mutex}; + +use bytes::Bytes; +use flume::{SendError, Sender, TrySendError}; + +use crate::v5::{ + client::get_ack_req, + outgoing_buf::OutgoingBuf, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + ClientError, EventLoop, MqttOptions, QoS, Request, +}; + +/// `AsyncClient` to communicate with MQTT `Eventloop` +/// This is cloneable and can be used to asynchronously Publish, Subscribe. +#[derive(Debug)] +pub struct AsyncClient { + pub(crate) outgoing_buf: Arc>, + pub(crate) request_tx: Sender<()>, +} + +impl AsyncClient { + /// Create a new `AsyncClient` + pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { + let eventloop = EventLoop::new(options, cap); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); + let request_tx = eventloop.handle(); + + let client = AsyncClient { + outgoing_buf, + request_tx, + }; + + (client, eventloop) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Publish to the eventloop + pub fn try_publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub async fn ack(&mut self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.notify_async().await?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.try_notify()?; + } + Ok(()) + } + + /// Sends a MQTT Publish to the eventloop + pub async fn publish_bytes( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: Bytes, + ) -> Result + where + S: Into, + { + let mut publish = Publish::from_bytes(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe to the eventloop + pub async fn subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub async fn subscribe_many(&mut self, topics: T) -> Result + where + T: IntoIterator, + { + let mut subscribe = Subscribe::new_many(topics); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn try_subscribe_many(&mut self, topics: T) -> Result + where + T: IntoIterator, + { + let mut subscribe = Subscribe::new_many(topics); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub async fn unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.notify_async().await?; + Ok(pkid) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.try_notify()?; + Ok(pkid) + } + + /// Sends a MQTT disconnect to the eventloop + #[inline] + pub async fn disconnect(&mut self) -> Result<(), ClientError> { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.notify_async().await + } + + /// Sends a MQTT disconnect to the eventloop + #[inline] + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + let mut request_buf = self.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.try_notify() + } + + #[inline] + async fn notify_async(&self) -> Result<(), ClientError> { + if let Err(SendError(_)) = self.request_tx.send_async(()).await { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + #[inline] + pub(crate) fn notify(&self) -> Result<(), ClientError> { + if let Err(SendError(_)) = self.request_tx.send(()) { + return Err(ClientError::EventloopClosed); + }; + Ok(()) + } + + #[inline] + fn try_notify(&self) -> Result<(), ClientError> { + if let Err(TrySendError::Disconnected(_)) = self.request_tx.try_send(()) { + return Err(ClientError::EventloopClosed); + } + Ok(()) + } +} diff --git a/rumqttc/src/v5/client/mod.rs b/rumqttc/src/v5/client/mod.rs new file mode 100644 index 000000000..f20efe87e --- /dev/null +++ b/rumqttc/src/v5/client/mod.rs @@ -0,0 +1,97 @@ +//! This module offers a high level synchronous and asynchronous abstraction to +//! async eventloop. +use crate::v5::{packet::*, ConnectionError, EventLoop, Request}; + +use flume::SendError; +use std::mem; +use tokio::runtime::{self, Runtime}; + +mod asyncclient; +pub use asyncclient::AsyncClient; +mod syncclient; +pub use syncclient::Client; + +/// Client Error +#[derive(Debug, thiserror::Error)] +pub enum ClientError { + #[error("Failed to send cancel request to eventloop")] + Cancel(SendError<()>), + #[error("Failed to send mqtt request to eventloop, the evenloop has been closed")] + EventloopClosed, + #[error("Failed to send mqtt request to evenloop, to requests buffer is full right now")] + RequestsFull, + #[error("Serialization error")] + Mqtt5(Error), +} + +fn get_ack_req(qos: QoS, pkid: u16) -> Option { + let ack = match qos { + QoS::AtMostOnce => return None, + QoS::AtLeastOnce => Request::PubAck(PubAck::new(pkid)), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(pkid)), + }; + Some(ack) +} + +/// MQTT connection. Maintains all the necessary state +pub struct Connection { + pub eventloop: EventLoop, + runtime: Option, +} + +impl Connection { + fn new(eventloop: EventLoop, runtime: Runtime) -> Connection { + Connection { + eventloop, + runtime: Some(runtime), + } + } + + /// Returns an iterator over this connection. Iterating over this is all that's + /// necessary to make connection progress and maintain a robust connection. + /// Just continuing to loop will reconnect + /// **NOTE** Don't block this while iterating + #[must_use = "Connection should be iterated over a loop to make progress"] + pub fn iter(&mut self) -> Iter { + let runtime = self.runtime.take().unwrap(); + Iter { + connection: self, + runtime, + } + } +} + +/// Iterator which polls the eventloop for connection progress +pub struct Iter<'a> { + connection: &'a mut Connection, + runtime: runtime::Runtime, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result<(), ConnectionError>; + + fn next(&mut self) -> Option { + let f = self.connection.eventloop.poll(); + match self.runtime.block_on(f) { + Ok(_) => Some(Ok(())), + // closing of request channel should stop the iterator + Err(ConnectionError::RequestsDone) => { + trace!("Done with requests"); + None + } + Err(ConnectionError::Cancel) => { + trace!("Cancellation request received"); + None + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a> Drop for Iter<'a> { + fn drop(&mut self) { + // TODO: Don't create new runtime in drop + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + self.connection.runtime = Some(mem::replace(&mut self.runtime, runtime)); + } +} diff --git a/rumqttc/src/v5/client/syncclient.rs b/rumqttc/src/v5/client/syncclient.rs new file mode 100644 index 000000000..425ad4e03 --- /dev/null +++ b/rumqttc/src/v5/client/syncclient.rs @@ -0,0 +1,184 @@ +use tokio::runtime; + +use crate::v5::{ + client::get_ack_req, + packet::{Publish, Subscribe, SubscribeFilter, Unsubscribe}, + AsyncClient, ClientError, Connection, MqttOptions, QoS, Request, +}; + +/// `Client` to communicate with MQTT eventloop `Connection`. +/// +/// Client is cloneable and can be used to synchronously Publish, Subscribe. +/// Asynchronous channel handle can also be extracted if necessary +pub struct Client { + client: AsyncClient, +} + +impl Client { + /// Create a new `Client` + pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { + let (client, eventloop) = AsyncClient::new(options, cap); + let client = Client { client }; + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let connection = Connection::new(eventloop, runtime); + (client, connection) + } + + /// Sends a MQTT Publish to the eventloop + pub fn publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + let mut publish = Publish::new(topic, qos, payload); + publish.retain = retain; + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + publish.pkid = pkid; + request_buf.buf.push_back(Request::Publish(publish)); + pkid + } else { + 0 + }; + self.client.notify()?; + Ok(pkid) + } + + pub fn try_publish( + &mut self, + topic: S, + qos: QoS, + retain: bool, + payload: V, + ) -> Result + where + S: Into, + V: Into>, + { + self.client.try_publish(topic, qos, retain, payload) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { + if let Some(ack) = get_ack_req(publish.qos, publish.pkid) { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(ack); + self.client.notify()?; + } + Ok(()) + } + + /// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set. + pub fn try_ack(&mut self, publish: &Publish) -> Result<(), ClientError> { + self.client.try_ack(publish) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn subscribe>(&mut self, topic: S, qos: QoS) -> Result { + let mut subscribe = Subscribe::new(topic.into(), qos); + let pkid = if qos != QoS::AtMostOnce { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + } else { + 0 + }; + self.client.notify()?; + Ok(pkid) + } + + /// Sends a MQTT Subscribe to the eventloop + pub fn try_subscribe>( + &mut self, + topic: S, + qos: QoS, + ) -> Result { + self.client.try_subscribe(topic, qos) + } + + /// Sends a MQTT Subscribe for multiple topics to the eventloop + pub fn subscribe_many(&mut self, topics: T) -> Result + where + T: IntoIterator, + { + let mut subscribe = Subscribe::new_many(topics); + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + subscribe.pkid = pkid; + request_buf.buf.push_back(Request::Subscribe(subscribe)); + pkid + }; + self.client.notify()?; + Ok(pkid) + } + + pub fn try_subscribe_many(&mut self, topics: T) -> Result + where + T: IntoIterator, + { + self.client.try_subscribe_many(topics) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn unsubscribe>(&mut self, topic: S) -> Result { + let mut unsubscribe = Unsubscribe::new(topic.into()); + let pkid = { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + let pkid = request_buf.increment_pkid(); + unsubscribe.pkid = pkid; + request_buf.buf.push_back(Request::Unsubscribe(unsubscribe)); + pkid + }; + self.client.notify()?; + Ok(pkid) + } + + /// Sends a MQTT Unsubscribe to the eventloop + pub fn try_unsubscribe>(&mut self, topic: S) -> Result { + self.client.try_unsubscribe(topic) + } + + /// Sends a MQTT disconnect to the eventloop + pub fn disconnect(&mut self) -> Result<(), ClientError> { + let mut request_buf = self.client.outgoing_buf.lock().unwrap(); + if request_buf.buf.len() == request_buf.capacity { + return Err(ClientError::RequestsFull); + } + request_buf.buf.push_back(Request::Disconnect); + self.client.notify() + } + + /// Sends a MQTT disconnect to the eventloop + pub fn try_disconnect(&mut self) -> Result<(), ClientError> { + self.client.try_disconnect() + } +} diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs new file mode 100644 index 000000000..7725c373f --- /dev/null +++ b/rumqttc/src/v5/eventloop.rs @@ -0,0 +1,361 @@ +#[cfg(feature = "use-rustls")] +use crate::v5::tls; +use crate::v5::{ + framed::Network, outgoing_buf::OutgoingBuf, packet::*, Incoming, MqttOptions, MqttState, + Packet, Request, StateError, Transport, +}; + +#[cfg(feature = "websocket")] +use async_tungstenite::tokio::{connect_async, connect_async_with_tls_connector}; +use flume::{bounded, Receiver, Sender}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::select; +use tokio::time::{self, error::Elapsed, Instant, Sleep}; +#[cfg(feature = "websocket")] +use ws_stream_tungstenite::WsStream; + +#[cfg(unix)] +use std::path::Path; +use std::{ + collections::VecDeque, + io, + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, + vec::IntoIter, +}; + +/// Critical errors during eventloop polling +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error("Mqtt state: {0}")] + MqttState(#[from] StateError), + #[error("Timeout")] + Timeout(#[from] Elapsed), + #[error("Packet parsing error: {0}")] + Mqtt5Bytes(Error), + #[cfg(feature = "use-rustls")] + #[error("Network: {0}")] + Network(#[from] tls::Error), + #[error("I/O: {0}")] + Io(#[from] io::Error), + #[error("Stream done")] + StreamDone, + #[error("Requests done")] + RequestsDone, + #[error("Cancel request by the user")] + Cancel, +} + +/// Eventloop with all the state of a connection +pub struct EventLoop { + /// Options of the current mqtt connection + pub options: MqttOptions, + /// Current state of the connection + pub state: MqttState, + outgoing_buf: Arc>, + outgoing_buf_cache: VecDeque, + /// Request stream + pub incoming_rx: Receiver<()>, + /// Requests handle to send requests + pub incoming_tx: Sender<()>, + /// Pending packets from last session + pub pending: IntoIter, + /// Network connection to the broker + pub(crate) network: Option, + /// Keep alive time + pub(crate) keepalive_timeout: Option>>, +} + +impl EventLoop { + /// New MQTT `EventLoop` + /// + /// When connection encounters critical errors (like auth failure), user has a choice to + /// access and update `options`, `state` and `requests`. + pub fn new(options: MqttOptions, cap: usize) -> EventLoop { + let (incoming_tx, incoming_rx) = bounded(1); + let pending = Vec::new(); + let pending = pending.into_iter(); + let max_inflight = options.inflight; + let manual_acks = options.manual_acks; + let state = MqttState::new(max_inflight, manual_acks, cap); + let outgoing_buf = state.outgoing_buf.clone(); + + EventLoop { + options, + state, + outgoing_buf, + outgoing_buf_cache: VecDeque::with_capacity(cap), + incoming_tx, + incoming_rx, + pending, + network: None, + keepalive_timeout: None, + } + } + + /// Returns a handle to communicate with this eventloop + #[inline] + pub fn handle(&self) -> Sender<()> { + self.incoming_tx.clone() + } + + fn clean(&mut self) { + self.network = None; + self.keepalive_timeout = None; + let pending = self.state.clean(); + self.pending = pending.into_iter(); + } + + /// Yields Next notification or outgoing request and periodically pings + /// the broker. Continuing to poll will reconnect to the broker if there is + /// a disconnection. + /// **NOTE** Don't block this while iterating + pub async fn poll(&mut self) -> Result<(), ConnectionError> { + if self.network.is_none() { + let (network, _connack) = connect(&self.options).await?; + self.network = Some(network); + + if self.keepalive_timeout.is_none() { + self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); + } + + return Ok(()); + } + + if let Err(e) = self.select().await { + self.clean(); + return Err(e); + } + + Ok(()) + } + + /// Select on network and requests and generate keepalive pings when necessary + async fn select(&mut self) -> Result<(), ConnectionError> { + let network = self.network.as_mut().unwrap(); + // let await_acks = self.state.await_acks; + let inflight_full = self.state.inflight >= self.options.inflight; + let throttle = self.options.pending_throttle; + let pending = self.pending.len() > 0; + let collision = self.state.collision.is_some(); + + // this loop is necessary as self.request_buf might be empty, in which case it is possible + // for self.state.events to be empty, and so popping off from it might return None. If None + // is returned, we select again. + loop { + select! { + // Pull a bunch of packets from network, reply in bunch and yield the first item + o = network.readb(&mut self.state) => { + o?; + // flush all the acks and return first incoming packet + network.flush(&mut self.state.write).await?; + return Ok(()); + }, + // Pull next request from user requests channel. + // If conditions in the below branch are for flow control. We read next user + // request only when inflight messages are < configured inflight and there are no + // collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this to + // work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue looks + // like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + o = self.incoming_rx.recv_async(), if !inflight_full && !pending && !collision => match o { + Ok(_request_notif) => { + // swapping to avoid blocking the mutex + std::mem::swap(&mut self.outgoing_buf_cache, &mut self.outgoing_buf.lock().unwrap().buf); + if self.outgoing_buf_cache.is_empty() { + continue; + } + for request in self.outgoing_buf_cache.drain(..) { + self.state.handle_outgoing_packet(request)?; + } + network.flush(&mut self.state.write).await?; + // remaining events in the self.state.events will be taken out in next call + // to poll() even before the select! is used. + return Ok(()) + } + Err(_) => return Err(ConnectionError::RequestsDone), + }, + // Handle the next pending packet from previous session. Disable + // this branch when done with all the pending packets + Some(request) = next_pending(throttle, &mut self.pending), if pending => { + self.state.handle_outgoing_packet(request)?; + network.flush(&mut self.state.write).await?; + return Ok(()) + }, + // We generate pings irrespective of network activity. This keeps the ping logic + // simple. We can change this behavior in future if necessary (to prevent extra pings) + _ = self.keepalive_timeout.as_mut().unwrap() => { + let timeout = self.keepalive_timeout.as_mut().unwrap(); + timeout.as_mut().reset(Instant::now() + self.options.keep_alive); + + self.state.handle_outgoing_packet(Request::PingReq)?; + network.flush(&mut self.state.write).await?; + return Ok(()) + } + } + } + } +} + +/// This stream internally processes requests from the request stream provided to the eventloop +/// while also consuming byte stream from the network and yielding mqtt packets as the output of +/// the stream. +/// This function (for convenience) includes internal delays for users to perform internal sleeps +/// between re-connections so that cancel semantics can be used during this sleep +async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> { + // connect to the broker + let mut network = match network_connect(options).await { + Ok(network) => network, + Err(e) => { + return Err(e); + } + }; + + // make MQTT connection request (which internally awaits for ack) + let packet = match mqtt_connect(options, &mut network).await { + Ok(p) => p, + Err(e) => return Err(e), + }; + + // Last session might contain packets which aren't acked. MQTT says these packets should be + // republished in the next session + // move pending messages from state to eventloop + // let pending = self.state.clean(); + // self.pending = pending.into_iter(); + Ok((network, packet)) +} + +async fn network_connect(options: &MqttOptions) -> Result { + let network = match options.transport() { + Transport::Tcp => { + let addr = options.broker_addr.as_str(); + let port = options.port; + let socket = TcpStream::connect((addr, port)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(feature = "use-rustls")] + Transport::Tls(tls_config) => { + let socket = tls::tls_connect(options, &tls_config).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(unix)] + Transport::Unix => { + let file = options.broker_addr.as_str(); + let socket = UnixStream::connect(Path::new(file)).await?; + Network::new(socket, options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Ws => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let (socket, _) = connect_async(request) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + #[cfg(feature = "websocket")] + Transport::Wss(tls_config) => { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(options.broker_addr.as_str()) + .header("Sec-WebSocket-Protocol", "mqttv3.1") + .body(()) + .unwrap(); + + let connector = tls::tls_connector(&tls_config).await?; + + let (socket, _) = connect_async_with_tls_connector(request, Some(connector)) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; + + Network::new(WsStream::new(socket), options.max_incoming_packet_size) + } + }; + + Ok(network) +} + +async fn mqtt_connect( + options: &MqttOptions, + network: &mut Network, +) -> Result { + let keep_alive = options.keep_alive().as_secs() as u16; + let clean_session = options.clean_session(); + let last_will = options.last_will(); + + let mut connect = Connect::new(options.client_id()); + connect.keep_alive = keep_alive; + connect.clean_session = clean_session; + connect.last_will = last_will; + + if let Some((username, password)) = options.credentials() { + let login = Login::new(username, password); + connect.login = Some(login); + } + + // mqtt connection with timeout + time::timeout(Duration::from_secs(options.connection_timeout()), async { + network.connect(connect).await?; + Ok::<_, ConnectionError>(()) + }) + .await??; + + // wait for 'timeout' time to validate connack + let packet = time::timeout(Duration::from_secs(options.connection_timeout()), async { + let packet = match network.read().await? { + Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { + Packet::ConnAck(connack) + } + Incoming::ConnAck(connack) => { + let error = format!("Broker rejected. Reason = {:?}", connack.code); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + packet => { + let error = format!("Expecting connack. Received = {:?}", packet); + return Err(io::Error::new(io::ErrorKind::InvalidData, error)); + } + }; + + io::Result::Ok(packet) + }) + .await??; + + Ok(packet) +} + +/// Returns the next pending packet asynchronously to be used in select! +/// This is a synchronous function but made async to make it fit in select! +pub(crate) async fn next_pending( + delay: Duration, + pending: &mut IntoIter, +) -> Option { + // return next packet with a delay + time::sleep(delay).await; + pending.next() +} diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs new file mode 100644 index 000000000..684694d5a --- /dev/null +++ b/rumqttc/src/v5/framed.rs @@ -0,0 +1,120 @@ +use bytes::BytesMut; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::v5::{packet::*, Incoming, MqttState, StateError}; +use std::io; + +/// Network transforms packets <-> frames efficiently. It takes +/// advantage of pre-allocation, buffering and vectorization when +/// appropriate to achieve performance +pub struct Network { + /// Socket for IO + socket: Box, + /// Buffered reads + read: BytesMut, + /// Maximum packet size + max_incoming_size: usize, + /// Maximum readv count + max_readb_count: usize, +} + +impl Network { + pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network { + let socket = Box::new(socket) as Box; + Network { + socket, + read: BytesMut::with_capacity(10 * 1024), + max_incoming_size, + max_readb_count: 10, + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + async fn read_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + loop { + let read = self.socket.read_buf(&mut self.read).await?; + if 0 == read { + return if self.read.is_empty() { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "connection closed by peer", + )) + } else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "connection reset by peer", + )) + }; + } + + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let required = match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => return Ok(packet), + Err(Error::InsufficientBytes(required)) => required, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + // read more packets until a frame can be created. This function + // blocks until a frame can be created. Use this in a select! branch + self.read_bytes(required).await?; + } + } + + /// Read packets in bulk. This allow replies to be in bulk. This method is used + /// after the connection is established to read a bunch of incoming packets + pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { + let mut count = 0; + loop { + match read(&mut self.read, self.max_incoming_size) { + Ok(packet) => { + state.handle_incoming_packet(packet)?; + + count += 1; + if count >= self.max_readb_count { + return Ok(()); + } + } + // If some packets are already framed, return those + Err(Error::InsufficientBytes(_)) if count > 0 => return Ok(()), + // Wait for more bytes until a frame can be created + Err(Error::InsufficientBytes(required)) => { + self.read_bytes(required).await?; + } + Err(e) => return Err(StateError::Deserialization(e)), + }; + } + } + + pub async fn connect(&mut self, connect: Connect) -> io::Result { + let mut write = BytesMut::new(); + let len = match connect.write(&mut write) { + Ok(size) => size, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + }; + + self.socket.write_all(&write[..]).await?; + Ok(len) + } + + pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { + if write.is_empty() { + return Ok(()); + } + + self.socket.write_all(&write[..]).await?; + write.clear(); + Ok(()) + } +} + +pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} +impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs new file mode 100644 index 000000000..07230ab62 --- /dev/null +++ b/rumqttc/src/v5/mod.rs @@ -0,0 +1,722 @@ +#[cfg(feature = "use-rustls")] +use std::sync::Arc; +use std::{ + collections::VecDeque, + fmt::{self, Debug, Formatter}, + time::Duration, +}; + +mod client; +mod eventloop; +mod framed; +mod notifier; +mod outgoing_buf; +#[allow(clippy::all)] +mod packet; +mod state; +#[cfg(feature = "use-rustls")] +mod tls; + +pub use client::{AsyncClient, Client, ClientError, Connection}; +pub use eventloop::{ConnectionError, EventLoop}; +pub use flume::{SendError, Sender, TrySendError}; +pub use notifier::Notifier; +pub use packet::*; +pub use state::{MqttState, StateError}; +#[cfg(feature = "use-rustls")] +pub use tls::Error; +#[cfg(feature = "use-rustls")] +pub use tokio_rustls::rustls::ClientConfig; + +pub type Incoming = Packet; + +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Clone, Debug, PartialEq)] +pub enum Request { + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubComp(PubComp), + PubRel(PubRel), + PingReq, + PingResp, + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + Disconnect, +} + +/// Key type for TLS authentication +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Key { + RSA(Vec), + ECC(Vec), +} + +impl From for Request { + fn from(publish: Publish) -> Request { + Request::Publish(publish) + } +} + +impl From for Request { + fn from(subscribe: Subscribe) -> Request { + Request::Subscribe(subscribe) + } +} + +impl From for Request { + fn from(unsubscribe: Unsubscribe) -> Request { + Request::Unsubscribe(unsubscribe) + } +} + +#[derive(Clone)] +pub enum Transport { + Tcp, + #[cfg(feature = "use-rustls")] + Tls(TlsConfiguration), + #[cfg(unix)] + Unix, + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + Ws, + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + Wss(TlsConfiguration), +} + +impl Default for Transport { + fn default() -> Self { + Self::tcp() + } +} + +impl Transport { + /// Use regular tcp as transport (default) + pub fn tcp() -> Self { + Self::Tcp + } + + /// Use secure tcp with tls as transport + #[cfg(feature = "use-rustls")] + pub fn tls( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + alpn, + client_auth, + }; + + Self::tls_with_config(config) + } + + #[cfg(feature = "use-rustls")] + pub fn tls_with_config(tls_config: TlsConfiguration) -> Self { + Self::Tls(tls_config) + } + + #[cfg(unix)] + pub fn unix() -> Self { + Self::Unix + } + + /// Use websockets as transport + #[cfg(feature = "websocket")] + #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] + pub fn ws() -> Self { + Self::Ws + } + + /// Use secure websockets with tls as transport + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss( + ca: Vec, + client_auth: Option<(Vec, Key)>, + alpn: Option>>, + ) -> Self { + let config = TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }; + + Self::wss_with_config(config) + } + + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "use-rustls", feature = "websocket"))))] + pub fn wss_with_config(tls_config: TlsConfiguration) -> Self { + Self::Wss(tls_config) + } +} + +#[derive(Clone)] +#[cfg(feature = "use-rustls")] +pub enum TlsConfiguration { + Simple { + /// connection method + ca: Vec, + /// alpn settings + alpn: Option>>, + /// tls client_authentication + client_auth: Option<(Vec, Key)>, + }, + /// Injected rustls ClientConfig for TLS, to allow more customisation. + Rustls(Arc), +} + +#[cfg(feature = "use-rustls")] +impl From for TlsConfiguration { + fn from(config: ClientConfig) -> Self { + TlsConfiguration::Rustls(Arc::new(config)) + } +} + +// TODO: Should all the options be exposed as public? Drawback +// would be loosing the ability to panic when the user options +// are wrong (e.g empty client id) or aggressive (keep alive time) +/// Options to configure the behaviour of mqtt connection +#[derive(Clone)] +pub struct MqttOptions { + /// broker address that you want to connect to + broker_addr: String, + /// broker port + port: u16, + // What transport protocol to use + transport: Transport, + /// keep alive time to send pingreq to broker when the connection is idle + keep_alive: Duration, + /// clean (or) persistent session + clean_session: bool, + /// client identifier + client_id: String, + /// username and password + credentials: Option<(String, String)>, + /// maximum incoming packet size (verifies remaining length of the packet) + max_incoming_packet_size: usize, + /// Maximum outgoing packet size (only verifies publish payload size) + // TODO Verify this with all packets. This can be packet.write but message left in + // the state might be a footgun as user has to explicitly clean it. Probably state + // has to be moved to network + max_outgoing_packet_size: usize, + /// request (publish, subscribe) channel capacity + request_channel_capacity: usize, + /// Max internal request batching + max_request_batch: usize, + /// Minimum delay time between consecutive outgoing packets + /// while retransmitting pending packets + pending_throttle: Duration, + /// maximum number of outgoing inflight messages + inflight: u16, + /// Last will that will be issued on unexpected disconnect + last_will: Option, + /// Connection timeout + conn_timeout: u64, + /// If set to `true` MQTT acknowledgements are not sent automatically. + /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. + manual_acks: bool, +} + +impl MqttOptions { + /// New mqtt options + pub fn new, T: Into>(id: S, host: T, port: u16) -> MqttOptions { + let id = id.into(); + if id.starts_with(' ') || id.is_empty() { + panic!("Invalid client id") + } + + MqttOptions { + broker_addr: host.into(), + port, + transport: Transport::tcp(), + keep_alive: Duration::from_secs(60), + clean_session: true, + client_id: id, + credentials: None, + max_incoming_packet_size: 10 * 1024, + max_outgoing_packet_size: 10 * 1024, + request_channel_capacity: 10, + max_request_batch: 0, + pending_throttle: Duration::from_micros(0), + inflight: 100, + last_will: None, + conn_timeout: 5, + manual_acks: false, + } + } + + /// Broker address + pub fn broker_address(&self) -> (String, u16) { + (self.broker_addr.clone(), self.port) + } + + pub fn set_last_will(&mut self, will: LastWill) -> &mut Self { + self.last_will = Some(will); + self + } + + pub fn last_will(&self) -> Option { + self.last_will.clone() + } + + pub fn set_transport(&mut self, transport: Transport) -> &mut Self { + self.transport = transport; + self + } + + pub fn transport(&self) -> Transport { + self.transport.clone() + } + + /// Set number of seconds after which client should ping the broker + /// if there is no other data exchange + pub fn set_keep_alive(&mut self, duration: Duration) -> &mut Self { + if duration.as_secs() < 5 { + panic!("Keep alives should be >= 5 secs"); + } + + self.keep_alive = duration; + self + } + + /// Keep alive time + pub fn keep_alive(&self) -> Duration { + self.keep_alive + } + + /// Client identifier + pub fn client_id(&self) -> String { + self.client_id.clone() + } + + /// Set packet size limit for outgoing an incoming packets + pub fn set_max_packet_size(&mut self, incoming: usize, outgoing: usize) -> &mut Self { + self.max_incoming_packet_size = incoming; + self.max_outgoing_packet_size = outgoing; + self + } + + /// Maximum packet size + pub fn max_packet_size(&self) -> usize { + self.max_incoming_packet_size + } + + /// `clean_session = true` removes all the state from queues & instructs the broker + /// to clean all the client state when client disconnects. + /// + /// When set `false`, broker will hold the client state and performs pending + /// operations on the client when reconnection with same `client_id` + /// happens. Local queue state is also held to retransmit packets after reconnection. + pub fn set_clean_session(&mut self, clean_session: bool) -> &mut Self { + self.clean_session = clean_session; + self + } + + /// Clean session + pub fn clean_session(&self) -> bool { + self.clean_session + } + + /// Username and password + pub fn set_credentials, P: Into>( + &mut self, + username: U, + password: P, + ) -> &mut Self { + self.credentials = Some((username.into(), password.into())); + self + } + + /// Security options + pub fn credentials(&self) -> Option<(String, String)> { + self.credentials.clone() + } + + /// Set request channel capacity + pub fn set_request_channel_capacity(&mut self, capacity: usize) -> &mut Self { + self.request_channel_capacity = capacity; + self + } + + /// Request channel capacity + pub fn request_channel_capacity(&self) -> usize { + self.request_channel_capacity + } + + /// Enables throttling and sets outoing message rate to the specified 'rate' + pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { + self.pending_throttle = duration; + self + } + + /// Outgoing message rate + pub fn pending_throttle(&self) -> Duration { + self.pending_throttle + } + + /// Set number of concurrent in flight messages + pub fn set_inflight(&mut self, inflight: u16) -> &mut Self { + if inflight == 0 { + panic!("zero in flight is not allowed") + } + + self.inflight = inflight; + self + } + + /// Number of concurrent in flight messages + pub fn inflight(&self) -> u16 { + self.inflight + } + + /// set connection timeout in secs + pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { + self.conn_timeout = timeout; + self + } + + /// get timeout in secs + pub fn connection_timeout(&self) -> u64 { + self.conn_timeout + } + + /// set manual acknowledgements + pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self { + self.manual_acks = manual_acks; + self + } + + /// get manual acknowledgements + pub fn manual_acks(&self) -> bool { + self.manual_acks + } +} + +#[cfg(feature = "url")] +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum OptionError { + #[error("Unsupported URL scheme.")] + Scheme, + + #[error("Missing client ID.")] + ClientId, + + #[error("Invalid keep-alive value.")] + KeepAlive, + + #[error("Invalid clean-session value.")] + CleanSession, + + #[error("Invalid max-incoming-packet-size value.")] + MaxIncomingPacketSize, + + #[error("Invalid max-outgoing-packet-size value.")] + MaxOutgoingPacketSize, + + #[error("Invalid request-channel-capacity value.")] + RequestChannelCapacity, + + #[error("Invalid max-request-batch value.")] + MaxRequestBatch, + + #[error("Invalid pending-throttle value.")] + PendingThrottle, + + #[error("Invalid inflight value.")] + Inflight, + + #[error("Invalid conn-timeout value.")] + ConnTimeout, + + #[error("Unknown option: {0}")] + Unknown(String), +} + +#[cfg(feature = "url")] +impl std::convert::TryFrom for MqttOptions { + type Error = OptionError; + + fn try_from(url: url::Url) -> Result { + use std::collections::HashMap; + + let broker_addr = url.host_str().unwrap_or_default().to_owned(); + + let (transport, default_port) = match url.scheme() { + // Encrypted connections are supported, but require explicit TLS configuration. We fall + // back to the unencrypted transport layer, so that `set_transport` can be used to + // configure the encrypted transport layer with the provided TLS configuration. + "mqtts" | "ssl" => (Transport::Tcp, 8883), + "mqtt" | "tcp" => (Transport::Tcp, 1883), + _ => return Err(OptionError::Scheme), + }; + + let port = url.port().unwrap_or(default_port); + + let mut queries = url.query_pairs().collect::>(); + + let keep_alive = Duration::from_secs( + queries + .remove("keep_alive_secs") + .map(|v| v.parse::().map_err(|_| OptionError::KeepAlive)) + .transpose()? + .unwrap_or(60), + ); + + let client_id = queries + .remove("client_id") + .ok_or(OptionError::ClientId)? + .into_owned(); + + let clean_session = queries + .remove("clean_session") + .map(|v| v.parse::().map_err(|_| OptionError::CleanSession)) + .transpose()? + .unwrap_or(true); + + let credentials = { + match url.username() { + "" => None, + username => Some(( + username.to_owned(), + url.password().unwrap_or_default().to_owned(), + )), + } + }; + + let max_incoming_packet_size = queries + .remove("max_incoming_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxIncomingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let max_outgoing_packet_size = queries + .remove("max_outgoing_packet_size_bytes") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::MaxOutgoingPacketSize) + }) + .transpose()? + .unwrap_or(10 * 1024); + + let request_channel_capacity = queries + .remove("request_channel_capacity_num") + .map(|v| { + v.parse::() + .map_err(|_| OptionError::RequestChannelCapacity) + }) + .transpose()? + .unwrap_or(10); + + let max_request_batch = queries + .remove("max_request_batch_num") + .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + .transpose()? + .unwrap_or(0); + + let pending_throttle = Duration::from_micros( + queries + .remove("pending_throttle_usecs") + .map(|v| v.parse::().map_err(|_| OptionError::PendingThrottle)) + .transpose()? + .unwrap_or(0), + ); + + let inflight = queries + .remove("inflight_num") + .map(|v| v.parse::().map_err(|_| OptionError::Inflight)) + .transpose()? + .unwrap_or(100); + + let conn_timeout = queries + .remove("conn_timeout_secs") + .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) + .transpose()? + .unwrap_or(5); + + if let Some((opt, _)) = queries.into_iter().next() { + return Err(OptionError::Unknown(opt.into_owned())); + } + + Ok(Self { + broker_addr, + port, + transport, + keep_alive, + clean_session, + client_id, + credentials, + max_incoming_packet_size, + max_outgoing_packet_size, + request_channel_capacity, + max_request_batch, + pending_throttle, + inflight, + last_will: None, + conn_timeout, + manual_acks: false, + }) + } +} + +// Implement Debug manually because ClientConfig doesn't implement it, so derive(Debug) doesn't +// work. +impl Debug for MqttOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("MqttOptions") + .field("broker_addr", &self.broker_addr) + .field("port", &self.port) + .field("keep_alive", &self.keep_alive) + .field("clean_session", &self.clean_session) + .field("client_id", &self.client_id) + .field("credentials", &self.credentials) + .field("max_packet_size", &self.max_incoming_packet_size) + .field("request_channel_capacity", &self.request_channel_capacity) + .field("max_request_batch", &self.max_request_batch) + .field("pending_throttle", &self.pending_throttle) + .field("inflight", &self.inflight) + .field("last_will", &self.last_will) + .field("conn_timeout", &self.conn_timeout) + .field("manual_acks", &self.manual_acks) + .finish() + } +} + +pub async fn connect(options: MqttOptions, cap: usize) -> Result<(AsyncClient, Notifier), ()> { + let mut eventloop = EventLoop::new(options, cap); + let outgoing_buf = eventloop.state.outgoing_buf.clone(); + let incoming_buf = eventloop.state.incoming_buf.clone(); + let incoming_buf_cache = VecDeque::with_capacity(cap); + let request_tx = eventloop.handle(); + + let client = AsyncClient { + outgoing_buf, + request_tx, + }; + + tokio::spawn(async move { + loop { + // TODO: maybe do something like retries for some specific errors? or maybe give user + // options to configure these retries? + eventloop.poll().await.unwrap(); + } + }); + + Ok((client, Notifier::new(incoming_buf, incoming_buf_cache))) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[should_panic] + fn client_id_startswith_space() { + let _mqtt_opts = MqttOptions::new(" client_a", "127.0.0.1", 1883).set_clean_session(true); + } + + #[test] + #[cfg(all(feature = "use-rustls", feature = "websocket"))] + fn no_scheme() { + let mut _mqtt_opts = MqttOptions::new("client_a", "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host", 443); + + _mqtt_opts.set_transport(crate::v5::Transport::wss(Vec::from("Test CA"), None, None)); + + if let crate::v5::Transport::Wss(TlsConfiguration::Simple { + ca, + client_auth, + alpn, + }) = _mqtt_opts.transport + { + assert_eq!(ca, Vec::from("Test CA")); + assert_eq!(client_auth, None); + assert_eq!(alpn, None); + } else { + panic!("Unexpected transport!"); + } + + assert_eq!(_mqtt_opts.broker_addr, "a3f8czas.iot.eu-west-1.amazonaws.com/mqtt?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=MyCreds%2F20201001%2Feu-west-1%2Fiotdevicegateway%2Faws4_request&X-Amz-Date=20201001T130812Z&X-Amz-Expires=7200&X-Amz-Signature=9ae09b49896f44270f2707551581953e6cac71a4ccf34c7c3415555be751b2d1&X-Amz-SignedHeaders=host"); + } + + #[test] + #[cfg(feature = "url")] + fn from_url() { + use std::convert::TryInto; + use std::str::FromStr; + + fn opt(s: &str) -> Result { + url::Url::from_str(s).expect("valid url").try_into() + } + fn ok(s: &str) -> MqttOptions { + opt(s).expect("valid options") + } + fn err(s: &str) -> OptionError { + opt(s).expect_err("invalid options") + } + + let v = ok("mqtt://host:42?client_id=foo"); + assert_eq!(v.broker_address(), ("host".to_owned(), 42)); + assert_eq!(v.client_id(), "foo".to_owned()); + + let v = ok("mqtt://host:42?client_id=foo&keep_alive_secs=5"); + assert_eq!(v.keep_alive, Duration::from_secs(5)); + + assert_eq!(err("mqtt://host:42"), OptionError::ClientId); + assert_eq!( + err("mqtt://host:42?client_id=foo&foo=bar"), + OptionError::Unknown("foo".to_owned()) + ); + assert_eq!(err("mqt://host:42?client_id=foo"), OptionError::Scheme); + assert_eq!( + err("mqtt://host:42?client_id=foo&keep_alive_secs=foo"), + OptionError::KeepAlive + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&clean_session=foo"), + OptionError::CleanSession + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_incoming_packet_size_bytes=foo"), + OptionError::MaxIncomingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_outgoing_packet_size_bytes=foo"), + OptionError::MaxOutgoingPacketSize + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&request_channel_capacity_num=foo"), + OptionError::RequestChannelCapacity + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + OptionError::MaxRequestBatch + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), + OptionError::PendingThrottle + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&inflight_num=foo"), + OptionError::Inflight + ); + assert_eq!( + err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + OptionError::ConnTimeout + ); + } + + #[test] + #[should_panic] + fn no_client_id() { + let _mqtt_opts = MqttOptions::new("", "127.0.0.1", 1883).set_clean_session(true); + } +} diff --git a/rumqttc/src/v5/notifier.rs b/rumqttc/src/v5/notifier.rs new file mode 100644 index 000000000..4aad90d16 --- /dev/null +++ b/rumqttc/src/v5/notifier.rs @@ -0,0 +1,56 @@ +use std::{ + collections::VecDeque, + mem, + sync::{Arc, Mutex}, +}; + +use crate::v5::Incoming; + +#[derive(Debug)] +pub struct Notifier { + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, +} + +impl Notifier { + #[inline] + pub(crate) fn new( + incoming_buf: Arc>>, + incoming_buf_cache: VecDeque, + ) -> Self { + Self { + incoming_buf, + incoming_buf_cache, + } + } + + #[inline] + pub fn next(&mut self) -> Option { + match self.incoming_buf_cache.pop_front() { + None => { + mem::swap( + &mut self.incoming_buf_cache, + &mut *self.incoming_buf.lock().unwrap(), + ); + self.incoming_buf_cache.pop_front() + } + val => val, + } + } + + #[inline] + pub fn iter(&mut self) -> NotifierIter<'_> { + NotifierIter(self) + } +} + +pub struct NotifierIter<'a>(&'a mut Notifier); + +impl<'a> Iterator for NotifierIter<'a> { + type Item = Incoming; + + #[inline] + fn next(&mut self) -> Option { + self.0.next() + } +} diff --git a/rumqttc/src/v5/outgoing_buf.rs b/rumqttc/src/v5/outgoing_buf.rs new file mode 100644 index 000000000..d036dd0c8 --- /dev/null +++ b/rumqttc/src/v5/outgoing_buf.rs @@ -0,0 +1,34 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use crate::v5::Request; + +#[derive(Debug)] +pub struct OutgoingBuf { + pub(crate) buf: VecDeque, + pub(crate) pkid_counter: u16, + pub(crate) capacity: usize, +} + +impl OutgoingBuf { + #[inline] + pub fn new(cap: usize) -> Arc> { + Arc::new(Mutex::new(Self { + buf: VecDeque::with_capacity(cap), + pkid_counter: 0, + capacity: cap, + })) + } + + #[inline] + pub fn increment_pkid(&mut self) -> u16 { + self.pkid_counter = if self.pkid_counter == self.capacity as u16 { + 1 + } else { + self.pkid_counter + 1 + }; + self.pkid_counter + } +} diff --git a/rumqttc/src/v5/packet/connack.rs b/rumqttc/src/v5/packet/connack.rs new file mode 100644 index 000000000..f0e0ebaee --- /dev/null +++ b/rumqttc/src/v5/packet/connack.rs @@ -0,0 +1,553 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum ConnectReturnCode { + Success = 0, + UnspecifiedError = 128, + MalformedPacket = 129, + ProtocolError = 130, + ImplementationSpecificError = 131, + UnsupportedProtocolVersion = 132, + ClientIdentifierNotValid = 133, + BadUserNamePassword = 134, + NotAuthorized = 135, + ServerUnavailable = 136, + ServerBusy = 137, + Banned = 138, + BadAuthenticationMethod = 140, + TopicNameInvalid = 144, + PacketTooLarge = 149, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + RetainNotSupported = 154, + QoSNotSupported = 155, + UseAnotherServer = 156, + ServerMoved = 157, + ConnectionRateExceeded = 159, +} + +/// Acknowledgement to connect packet +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, +} + +impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(self.session_present as u8); + buffer.put_u8(self.code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, +} + +impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let id = read_mqtt_string(&mut bytes)?; + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_string(&mut bytes)?; + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +/// Connection return code type +fn connect_return(num: u8) -> Result { + let code = match num { + 0 => ConnectReturnCode::Success, + 128 => ConnectReturnCode::UnspecifiedError, + 129 => ConnectReturnCode::MalformedPacket, + 130 => ConnectReturnCode::ProtocolError, + 131 => ConnectReturnCode::ImplementationSpecificError, + 132 => ConnectReturnCode::UnsupportedProtocolVersion, + 133 => ConnectReturnCode::ClientIdentifierNotValid, + 134 => ConnectReturnCode::BadUserNamePassword, + 135 => ConnectReturnCode::NotAuthorized, + 136 => ConnectReturnCode::ServerUnavailable, + 137 => ConnectReturnCode::ServerBusy, + 138 => ConnectReturnCode::Banned, + 140 => ConnectReturnCode::BadAuthenticationMethod, + 144 => ConnectReturnCode::TopicNameInvalid, + 149 => ConnectReturnCode::PacketTooLarge, + 151 => ConnectReturnCode::QuotaExceeded, + 153 => ConnectReturnCode::PayloadFormatInvalid, + 154 => ConnectReturnCode::RetainNotSupported, + 155 => ConnectReturnCode::QoSNotSupported, + 156 => ConnectReturnCode::UseAnotherServer, + 157 => ConnectReturnCode::ServerMoved, + 159 => ConnectReturnCode::ConnectionRateExceeded, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample() -> ConnAck { + let properties = ConnAckProperties { + session_expiry_interval: Some(1234), + receive_max: Some(432), + max_qos: Some(2), + retain_available: Some(1), + max_packet_size: Some(100), + assigned_client_identifier: Some("test".to_owned()), + topic_alias_max: Some(456), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + wildcard_subscription_available: Some(1), + subscription_identifiers_available: Some(1), + shared_subscription_available: Some(0), + server_keep_alive: Some(1234), + response_information: Some("test".to_owned()), + server_reference: Some("test".to_owned()), + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + ConnAck { + session_present: false, + code: ConnectReturnCode::Success, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x20, // Packet type + 0x57, // Remaining length + 0x00, 0x00, // Session, code + 0x54, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x21, 0x01, 0xb0, // Receive maximum + 0x24, 0x02, // Maximum qos + 0x25, 0x01, // Retain available + 0x27, 0x00, 0x00, 0x00, 0x64, // Maximum packet size + 0x12, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Assigned client identifier + 0x22, 0x01, 0xc8, // Topic alias max + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x28, 0x01, // wildcard_subscription_available + 0x29, 0x01, // subscription_identifiers_available + 0x2a, 0x00, // shared_subscription_available + 0x13, 0x04, 0xd2, // server keep_alive + 0x1a, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // response_information + 0x1c, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication data + ] + } + + #[test] + fn connack_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connack_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connack = ConnAck::read(fixed_header, connack_bytes).unwrap(); + + assert_eq!(connack, sample()); + } + + #[test] + fn connack_encoding_works() { + let connack = sample(); + let mut buf = BytesMut::new(); + connack.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/connect.rs b/rumqttc/src/v5/packet/connect.rs new file mode 100644 index 000000000..1c630ae79 --- /dev/null +++ b/rumqttc/src/v5/packet/connect.rs @@ -0,0 +1,928 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Connection packet initiated by the client +#[derive(Debug, Clone, PartialEq)] +pub struct Connect { + /// Mqtt protocol version + pub protocol: Protocol, + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, +} + +impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn set_login, P: Into>(&mut self, u: U, p: P) -> &mut Connect { + let login = Login { + username: u.into(), + password: p.into(), + }; + + self.login = Some(login); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_string(&mut bytes)?; + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol = match protocol_level { + 4 => Protocol::V4, + 5 => Protocol::V5, + num => return Err(Error::InvalidProtocolLevel(num)), + }; + + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + // Properties in variable header + let properties = match protocol { + Protocol::V5 => ConnectProperties::read(&mut bytes)?, + Protocol::V4 => None, + }; + + let client_id = read_mqtt_string(&mut bytes)?; + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + protocol, + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + + match self.protocol { + Protocol::V4 => buffer.put_u8(0x04), + Protocol::V5 => buffer.put_u8(0x05), + } + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } +} + +/// LastWill that broker forwards on behalf of the client +#[derive(Debug, Clone, PartialEq)] +pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + pub properties: Option, +} + +impl LastWill { + pub fn new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 0; + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + }; + + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + // Properties in variable header + let properties = WillProperties::read(&mut bytes)?; + + let will_topic = read_mqtt_string(&mut bytes)?; + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + properties, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct WillProperties { + pub delay_interval: Option, + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub content_type: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, +} + +impl WillProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.delay_interval.is_some() { + len += 1 + 4; + } + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut delay_interval = None; + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut content_type = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::WillDelayInterval => { + delay_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(WillProperties { + delay_interval, + payload_format_indicator, + message_expiry_interval, + content_type, + response_topic, + correlation_data, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(delay_interval) = self.delay_interval { + buffer.put_u8(PropertyType::WillDelayInterval as u8); + buffer.put_u32(delay_interval); + } + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Login { + pub username: String, + pub password: String, +} + +impl Login { + pub fn new, P: Into>(u: U, p: P) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => read_mqtt_string(&mut bytes)?, + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, +} + +impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(&mut bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_maximum { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(request_response_info) = self.request_response_info { + buffer.put_u8(PropertyType::RequestResponseInformation as u8); + buffer.put_u8(request_response_info); + } + + if let Some(request_problem_info) = self.request_problem_info { + buffer.put_u8(PropertyType::RequestProblemInformation as u8); + buffer.put_u8(request_problem_info); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn sample() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will_properties = WillProperties { + delay_interval: Some(1234), + payload_format_indicator: Some(0), + message_expiry_interval: Some(4321), + content_type: Some("test".to_owned()), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: Some(will_properties), + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x9d, // remaining len + 0x01, // remaining len + 0x00, 0x04, // 4 + 0x4d, // M + 0x51, // Q + 0x54, // T + 0x54, // T + 0x05, // Level + 0xc6, // connect flags + 0x00, 0x00, // keep alive + 0x2f, // properties len + 0x11, 0x00, 0x00, 0x04, 0xd2, // session expiry interval + 0x21, 0x01, 0xb0, // receive_maximum + 0x27, 0x00, 0x00, 0x00, 0x64, // max packet size + 0x22, 0x01, 0xc8, // topic_alias_max + 0x19, 0x01, // request_response_info + 0x17, 0x01, // request_problem_info + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user + 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // authentication_method + 0x16, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // authentication_data + 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, // client id + 0x2f, // will properties len + 0x18, 0x00, 0x00, 0x04, 0xd2, // will delay interval + 0x01, 0x00, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message expiry interval + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content type + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // will user properties + 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, // will topic + 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, // will payload + 0x00, 0x06, 0x6d, 0x61, 0x74, 0x74, 0x65, 0x6f, // username + 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, // password + ] + } + + #[test] + fn connect1_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample()); + } + + #[test] + fn connect1_encoding_works() { + let connect = sample(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Connect { + Connect { + protocol: Protocol::V5, + keep_alive: 10, + properties: None, + client_id: "hackathonmqtt5test".to_owned(), + clean_session: true, + last_will: None, + login: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x10, // packet type + 0x1f, 0x00, // remaining len + 0x04, // 4 + 0x4d, 0x51, 0x54, 0x54, // MQTT + 0x05, // level + 0x02, // connect flags + 0x00, 0x0a, // keep alive + 0x00, 0x00, 0x12, 0x68, 0x61, 0x63, 0x6b, 0x61, 0x74, 0x68, 0x6f, 0x6e, 0x6d, 0x71, + 0x74, 0x74, 0x35, 0x74, 0x65, 0x73, 0x74, // payload + 0x10, 0x11, 0x12, // extra bytes in the stream + ] + } + + #[test] + fn connect2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample2()); + } + + #[test] + fn connect2_encoding_works() { + let connect = sample2(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample2_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len() - 3)]); + } + + fn sample3() -> Connect { + let connect_properties = ConnectProperties { + session_expiry_interval: Some(1234), + receive_maximum: Some(432), + max_packet_size: Some(100), + topic_alias_max: Some(456), + request_response_info: Some(1), + request_problem_info: Some(1), + user_properties: vec![("test".to_owned(), "test".to_owned())], + authentication_method: Some("test".to_owned()), + authentication_data: Some(Bytes::from(vec![1, 2, 3, 4])), + }; + + let will = LastWill { + topic: "mydevice/status".to_string(), + message: Bytes::from(vec![b'd', b'e', b'a', b'd']), + qos: QoS::AtMostOnce, + retain: false, + properties: None, + }; + + let login = Login { + username: "matteo".to_string(), + password: "collina".to_string(), + }; + + Connect { + protocol: Protocol::V5, + keep_alive: 0, + properties: Some(connect_properties), + client_id: "my-device".to_string(), + clean_session: true, + last_will: Some(will), + login: Some(login), + } + } + + fn sample3_bytes() -> Vec { + vec![ + 0x10, 0x6e, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0xc6, 0x00, 0x00, 0x2f, 0x11, + 0x00, 0x00, 0x04, 0xd2, 0x21, 0x01, 0xb0, 0x27, 0x00, 0x00, 0x00, 0x64, 0x22, 0x01, + 0xc8, 0x19, 0x01, 0x17, 0x01, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, + 0x74, 0x65, 0x73, 0x74, 0x15, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x16, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x00, 0x09, 0x6d, 0x79, 0x2d, 0x64, 0x65, 0x76, 0x69, 0x63, + 0x65, 0x00, 0x00, 0x0f, 0x6d, 0x79, 0x64, 0x65, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x00, 0x04, 0x64, 0x65, 0x61, 0x64, 0x00, 0x06, 0x6d, + 0x61, 0x74, 0x74, 0x65, 0x6f, 0x00, 0x07, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x6e, 0x61, + ] + } + + #[test] + fn connect3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let connect_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let connect = Connect::read(fixed_header, connect_bytes).unwrap(); + assert_eq!(connect, sample3()); + } + + #[test] + fn connect3_encoding_works() { + let connect = sample3(); + let mut buf = BytesMut::new(); + connect.write(&mut buf).unwrap(); + + let expected = sample3_bytes(); + assert_eq!(&buf[..], &expected[0..(expected.len())]); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/disconnect.rs b/rumqttc/src/v5/packet/disconnect.rs new file mode 100644 index 000000000..3662087f1 --- /dev/null +++ b/rumqttc/src/v5/packet/disconnect.rs @@ -0,0 +1,434 @@ +use std::convert::{TryFrom, TryInto}; + +use bytes::{BufMut, Bytes, BytesMut}; + +use super::*; + +use super::{property, PropertyType}; + +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum DisconnectReasonCode { + /// Close the connection normally. Do not send the Will Message. + NormalDisconnection = 0x00, + /// The Client wishes to disconnect but requires that the Server also publishes its Will Message. + DisconnectWithWillMessage = 0x04, + /// The Connection is closed but the sender either does not wish to reveal the reason, or none of the other Reason Codes apply. + UnspecifiedError = 0x80, + /// The received packet does not conform to this specification. + MalformedPacket = 0x81, + /// An unexpected or out of order packet was received. + ProtocolError = 0x82, + /// The packet received is valid but cannot be processed by this implementation. + ImplementationSpecificError = 0x83, + /// The request is not authorized. + NotAuthorized = 0x87, + /// The Server is busy and cannot continue processing requests from this Client. + ServerBusy = 0x89, + /// The Server is shutting down. + ServerShuttingDown = 0x8B, + /// The Connection is closed because no packet has been received for 1.5 times the Keepalive time. + KeepAliveTimeout = 0x8D, + /// Another Connection using the same ClientID has connected causing this Connection to be closed. + SessionTakenOver = 0x8E, + /// The Topic Filter is correctly formed, but is not accepted by this Sever. + TopicFilterInvalid = 0x8F, + /// The Topic Name is correctly formed, but is not accepted by this Client or Server. + TopicNameInvalid = 0x90, + /// The Client or Server has received more than Receive Maximum publication for which it has not sent PUBACK or PUBCOMP. + ReceiveMaximumExceeded = 0x93, + /// The Client or Server has received a PUBLISH packet containing a Topic Alias which is greater than the Maximum Topic Alias it sent in the CONNECT or CONNACK packet. + TopicAliasInvalid = 0x94, + /// The packet size is greater than Maximum Packet Size for this Client or Server. + PacketTooLarge = 0x95, + /// The received data rate is too high. + MessageRateTooHigh = 0x96, + /// An implementation or administrative imposed limit has been exceeded. + QuotaExceeded = 0x97, + /// The Connection is closed due to an administrative action. + AdministrativeAction = 0x98, + /// The payload format does not match the one specified by the Payload Format Indicator. + PayloadFormatInvalid = 0x99, + /// The Server has does not support retained messages. + RetainNotSupported = 0x9A, + /// The Client specified a QoS greater than the QoS specified in a Maximum QoS in the CONNACK. + QoSNotSupported = 0x9B, + /// The Client should temporarily change its Server. + UseAnotherServer = 0x9C, + /// The Server is moved and the Client should permanently change its server location. + ServerMoved = 0x9D, + /// The Server does not support Shared Subscriptions. + SharedSubscriptionNotSupported = 0x9E, + /// This connection is closed because the connection rate is too high. + ConnectionRateExceeded = 0x9F, + /// The maximum connection time authorized for this connection has been exceeded. + MaximumConnectTime = 0xA0, + /// The Server does not support Subscription Identifiers; the subscription is not accepted. + SubscriptionIdentifiersNotSupported = 0xA1, + /// The Server does not support Wildcard subscription; the subscription is not accepted. + WildcardSubscriptionsNotSupported = 0xA2, +} + +impl TryFrom for DisconnectReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let rc = match value { + 0x00 => Self::NormalDisconnection, + 0x04 => Self::DisconnectWithWillMessage, + 0x80 => Self::UnspecifiedError, + 0x81 => Self::MalformedPacket, + 0x82 => Self::ProtocolError, + 0x83 => Self::ImplementationSpecificError, + 0x87 => Self::NotAuthorized, + 0x89 => Self::ServerBusy, + 0x8B => Self::ServerShuttingDown, + 0x8D => Self::KeepAliveTimeout, + 0x8E => Self::SessionTakenOver, + 0x8F => Self::TopicFilterInvalid, + 0x90 => Self::TopicNameInvalid, + 0x93 => Self::ReceiveMaximumExceeded, + 0x94 => Self::TopicAliasInvalid, + 0x95 => Self::PacketTooLarge, + 0x96 => Self::MessageRateTooHigh, + 0x97 => Self::QuotaExceeded, + 0x98 => Self::AdministrativeAction, + 0x99 => Self::PayloadFormatInvalid, + 0x9A => Self::RetainNotSupported, + 0x9B => Self::QoSNotSupported, + 0x9C => Self::UseAnotherServer, + 0x9D => Self::ServerMoved, + 0x9E => Self::SharedSubscriptionNotSupported, + 0x9F => Self::ConnectionRateExceeded, + 0xA0 => Self::MaximumConnectTime, + 0xA1 => Self::SubscriptionIdentifiersNotSupported, + 0xA2 => Self::WildcardSubscriptionsNotSupported, + other => return Err(Error::InvalidConnectReturnCode(other)), + }; + + Ok(rc) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DisconnectProperties { + /// Session Expiry Interval in seconds + pub session_expiry_interval: Option, + + /// Human readable reason for the disconnect + pub reason_string: Option, + + /// List of user properties + pub user_properties: Vec<(String, String)>, + + /// String which can be used by the Client to identify another Server to use. + pub server_reference: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Disconnect { + /// Disconnect Reason Code + pub reason_code: DisconnectReasonCode, + + /// Disconnect Properties + pub properties: Option, +} + +impl DisconnectProperties { + pub fn new() -> Self { + Self { + session_expiry_interval: None, + reason_string: None, + user_properties: Vec::new(), + server_reference: None, + } + } + + fn len(&self) -> usize { + let mut length = 0; + + if self.session_expiry_interval.is_some() { + length += 1 + 4; + } + + if let Some(reason) = &self.reason_string { + length += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + length += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(server_reference) = &self.server_reference { + length += 1 + 2 + server_reference.len(); + } + + length + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let (properties_len_len, properties_len) = length(bytes.iter())?; + + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut session_expiry_interval = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut server_reference = None; + + let mut cursor = 0; + + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::ServerReference => { + let reference = read_mqtt_string(&mut bytes)?; + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + let properties = Self { + session_expiry_interval, + reason_string, + user_properties, + server_reference, + }; + + Ok(Some(properties)) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let length = self.len(); + write_remaining_length(buffer, length)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + Ok(()) + } +} + +impl Disconnect { + pub fn new() -> Self { + Self { + reason_code: DisconnectReasonCode::NormalDisconnection, + properties: None, + } + } + + fn len(&self) -> usize { + if self.reason_code == DisconnectReasonCode::NormalDisconnection + && self.properties.is_none() + { + return 2; // Packet type + 0x00 + } + + let mut length = 0; + + match &self.properties { + Some(properties) => { + length += 1; // Disconnect Reason Code + + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + length += properties_len_len + properties_len; + } + None if self.reason_code == DisconnectReasonCode::NormalDisconnection => {} + None => { + length += 1; // Disconnect Reason Code + } + }; + + length + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let packet_type = fixed_header.byte1 >> 4; + let flags = fixed_header.byte1 & 0b0000_1111; + + bytes.advance(fixed_header.fixed_header_len); + + if packet_type != PacketType::Disconnect as u8 { + return Err(Error::InvalidPacketType(packet_type)); + }; + + if flags != 0x00 { + return Err(Error::MalformedPacket); + }; + + if fixed_header.remaining_len == 0 { + return Ok(Self::new()); + } + + let reason_code = read_u8(&mut bytes)?; + + let disconnect = Self { + reason_code: reason_code.try_into()?, + properties: DisconnectProperties::extract(&mut bytes)?, + }; + + Ok(disconnect) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xE0); + + let length = self.len(); + + if length == 2 { + buffer.put_u8(0x00); + + return Ok(length); + } + + let len_len = write_remaining_length(buffer, length)?; + + buffer.put_u8(self.reason_code as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + len_len + length) + } +} + +#[cfg(test)] +mod test { + use bytes::BytesMut; + + use super::parse_fixed_header; + + use super::{Disconnect, DisconnectProperties, DisconnectReasonCode}; + + #[test] + fn disconnect1_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + let expected = Disconnect::new(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect1_encoding_works() { + let mut buffer = BytesMut::new(); + let disconnect = Disconnect::new(); + let expected = [ + 0xE0, // Packet type + 0x00, // Remaining length + ]; + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } + + fn sample2() -> Disconnect { + let properties = DisconnectProperties { + // TODO: change to 2137 xD + session_expiry_interval: Some(1234), + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + server_reference: Some("test".to_owned()), + }; + + Disconnect { + reason_code: DisconnectReasonCode::UnspecifiedError, + properties: Some(properties), + } + } + + fn sample_bytes2() -> Vec { + vec![ + 0xE0, // Packet type + 0x22, // Remaining length + 0x80, // Disconnect Reason Code + 0x20, // Properties length + 0x11, 0x00, 0x00, 0x04, 0xd2, // Session expiry interval + 0x1F, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // Reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // User properties + 0x1C, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // server reference + ] + } + + #[test] + fn disconnect2_parsing_works() { + let mut buffer = bytes::BytesMut::new(); + let packet_bytes = sample_bytes2(); + let expected = sample2(); + + buffer.extend_from_slice(&packet_bytes[..]); + + let fixed_header = parse_fixed_header(buffer.iter()).unwrap(); + let disconnect_bytes = buffer.split_to(fixed_header.frame_length()).freeze(); + let disconnect = Disconnect::read(fixed_header, disconnect_bytes).unwrap(); + + assert_eq!(disconnect, expected); + } + + #[test] + fn disconnect2_encoding_works() { + let mut buffer = BytesMut::new(); + + let disconnect = sample2(); + let expected = sample_bytes2(); + + disconnect.write(&mut buffer).unwrap(); + + assert_eq!(&buffer[..], &expected); + } +} diff --git a/rumqttc/src/v5/packet/mod.rs b/rumqttc/src/v5/packet/mod.rs new file mode 100644 index 000000000..8f954c1cf --- /dev/null +++ b/rumqttc/src/v5/packet/mod.rs @@ -0,0 +1,489 @@ +use std::{ + fmt::{self, Display, Formatter}, + slice::Iter, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +mod connack; +mod connect; +mod disconnect; +mod ping; +mod puback; +mod pubcomp; +mod publish; +mod pubrec; +mod pubrel; +mod suback; +mod subscribe; +mod unsuback; +mod unsubscribe; + +pub use connack::*; +pub use connect::*; +pub use disconnect::*; +pub use ping::*; +pub use puback::*; +pub use pubcomp::*; +pub use publish::*; +pub use pubrec::*; +pub use pubrel::*; +pub use suback::*; +pub use subscribe::*; +pub use unsuback::*; +pub use unsubscribe::*; + +/// Encapsulates all MQTT packet types +#[derive(Debug, Clone, PartialEq)] +pub enum Packet { + Connect(Connect), + ConnAck(ConnAck), + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubRel(PubRel), + PubComp(PubComp), + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + PingReq, + PingResp, + Disconnect(Disconnect), +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Error { + NotConnect(PacketType), + UnexpectedConnect, + InvalidConnectReturnCode(u8), + InvalidReason(u8), + InvalidProtocol, + InvalidProtocolLevel(u8), + IncorrectPacketFormat, + InvalidPacketType(u8), + InvalidPropertyType(u8), + InvalidRetainForwardRule(u8), + InvalidQoS(u8), + InvalidSubscribeReasonCode(u8), + PacketIdZero, + SubscriptionIdZero, + PayloadSizeIncorrect, + PayloadTooLong, + PayloadSizeLimitExceeded(usize), + PayloadRequired, + TopicNotUtf8, + BoundaryCrossed(usize), + MalformedPacket, + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + InsufficientBytes(usize), +} + +/// Protocol type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Protocol { + V4, + V5, +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::read(fixed_header, packet)?), + }; + + Ok(packet) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Reads a string from bytes stream +fn read_mqtt_string(stream: &mut Bytes) -> Result { + let s = read_mqtt_bytes(stream)?; + match String::from_utf8(s.to_vec()) { + Ok(v) => Ok(v), + Err(_e) => Err(Error::TopicNotUtf8), + } +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes +fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.is_empty() { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Error = {:?}", self) + } +} diff --git a/rumqttc/src/v5/packet/ping.rs b/rumqttc/src/v5/packet/ping.rs new file mode 100644 index 000000000..1072029bd --- /dev/null +++ b/rumqttc/src/v5/packet/ping.rs @@ -0,0 +1,20 @@ +use super::*; +use bytes::{BufMut, BytesMut}; + +pub struct PingReq; + +impl PingReq { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +pub struct PingResp; + +impl PingResp { + pub fn write(&self, payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} diff --git a/rumqttc/src/v5/packet/puback.rs b/rumqttc/src/v5/packet/puback.rs new file mode 100644 index 000000000..51131949e --- /dev/null +++ b/rumqttc/src/v5/packet/puback.rs @@ -0,0 +1,324 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, +} + +impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties, sending reason code is optional + if self.reason == PubAckReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // Reason code is optional with success if there are no properties + if self.reason == PubAckReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(v5)] +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubAck { + let properties = PubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x40, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn puback_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample()); + } + + #[test] + fn puback_encoding_works() { + let puback = sample(); + let mut buf = BytesMut::new(); + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::NoMatchingSubscribers, + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![0x40, 0x03, 0x00, 0x2a, 0x10] + } + + #[test] + fn puback2_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample2_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample2()); + } + + #[test] + fn puback2_encoding_works() { + let puback = sample2(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample2_bytes()); + } + + fn sample3() -> PubAck { + PubAck { + pkid: 42, + reason: PubAckReason::Success, + properties: None, + } + } + + fn sample3_bytes() -> Vec { + vec![0x40, 0x02, 0x00, 0x2a] + } + + #[test] + fn puback3_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample3_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let puback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let puback = PubAck::read(fixed_header, puback_bytes).unwrap(); + assert_eq!(puback, sample3()); + } + + #[test] + fn puback3_encoding_works() { + let puback = sample3(); + let mut buf = BytesMut::new(); + + puback.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample3_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubcomp.rs b/rumqttc/src/v5/packet/pubcomp.rs new file mode 100644 index 000000000..badb97867 --- /dev/null +++ b/rumqttc/src/v5/packet/pubcomp.rs @@ -0,0 +1,237 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubCompReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubComp { + pub pkid: u16, + pub reason: PubCompReason, + pub properties: Option, +} + +impl PubComp { + pub fn new(pkid: u16) -> PubComp { + PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if fixed_header.remaining_len == 2 { + return Ok(PubComp { + pkid, + reason: PubCompReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubComp { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubComp { + pkid, + reason: reason(ack_reason)?, + properties: PubCompProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x70); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubCompReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubCompProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubCompProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubCompProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubCompReason::Success, + 146 => PubCompReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubComp { + let properties = PubCompProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubComp { + pkid: 42, + reason: PubCompReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x70, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubcomp_parsing_works_correctly() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubcomp_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubcomp = PubComp::read(fixed_header, pubcomp_bytes).unwrap(); + assert_eq!(pubcomp, sample()); + } + + #[test] + fn pubcomp_encoding_works_correctly() { + let pubcomp = sample(); + let mut buf = BytesMut::new(); + pubcomp.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/publish.rs b/rumqttc/src/v5/packet/publish.rs new file mode 100644 index 000000000..9f0e228ea --- /dev/null +++ b/rumqttc/src/v5/packet/publish.rs @@ -0,0 +1,394 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Publish packet +#[derive(Clone, PartialEq)] +pub struct Publish { + pub dup: bool, + pub qos: QoS, + pub retain: bool, + pub topic: String, + pub pkid: u16, + pub properties: Option, + pub payload: Bytes, +} + +impl Publish { + pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload: Bytes::from(payload.into()), + } + } + + pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + Publish { + dup: false, + qos, + retain: false, + pkid: 0, + topic: topic.into(), + properties: None, + payload, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.topic.len(); + if self.qos != QoS::AtMostOnce && self.pkid != 0 { + len += 2; + } + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len += self.payload.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let qos = qos((fixed_header.byte1 & 0b0110) >> 1)?; + let dup = (fixed_header.byte1 & 0b1000) != 0; + let retain = (fixed_header.byte1 & 0b0001) != 0; + + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let topic = read_mqtt_string(&mut bytes)?; + + // Packet identifier exists where QoS > 0 + let pkid = match qos { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce | QoS::ExactlyOnce => read_u16(&mut bytes)?, + }; + + if qos != QoS::AtMostOnce && pkid == 0 { + return Err(Error::PacketIdZero); + } + + let publish = Publish { + dup, + retain, + qos, + pkid, + topic, + properties: PublishProperties::extract(&mut bytes)?, + payload: bytes, + }; + + Ok(publish) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + + let dup = self.dup as u8; + let qos = self.qos as u8; + let retain = self.retain as u8; + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, self.topic.as_str()); + + if self.qos != QoS::AtMostOnce { + let pkid = self.pkid; + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + buffer.extend_from_slice(&self.payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PublishProperties { + pub payload_format_indicator: Option, + pub message_expiry_interval: Option, + pub topic_alias: Option, + pub response_topic: Option, + pub correlation_data: Option, + pub user_properties: Vec<(String, String)>, + pub subscription_identifiers: Vec, + pub content_type: Option, +} + +impl PublishProperties { + fn len(&self) -> usize { + let mut len = 0; + + if self.payload_format_indicator.is_some() { + len += 1 + 1; + } + + if self.message_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.topic_alias.is_some() { + len += 1 + 2; + } + + if let Some(topic) = &self.response_topic { + len += 1 + 2 + topic.len() + } + + if let Some(data) = &self.correlation_data { + len += 1 + 2 + data.len() + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + for id in self.subscription_identifiers.iter() { + len += 1 + len_len(*id); + } + + if let Some(typ) = &self.content_type { + len += 1 + 2 + typ.len() + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut payload_format_indicator = None; + let mut message_expiry_interval = None; + let mut topic_alias = None; + let mut response_topic = None; + let mut correlation_data = None; + let mut user_properties = Vec::new(); + let mut subscription_identifiers = Vec::new(); + let mut content_type = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::PayloadFormatIndicator => { + payload_format_indicator = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::MessageExpiryInterval => { + message_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAlias => { + topic_alias = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseTopic => { + let topic = read_mqtt_string(&mut bytes)?; + cursor += 2 + topic.len(); + response_topic = Some(topic); + } + PropertyType::CorrelationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + correlation_data = Some(data); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::SubscriptionIdentifier => { + let (id_len, id) = length(bytes.iter())?; + cursor += 1 + id_len; + bytes.advance(id_len); + subscription_identifiers.push(id); + } + PropertyType::ContentType => { + let typ = read_mqtt_string(&mut bytes)?; + cursor += 2 + typ.len(); + content_type = Some(typ); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PublishProperties { + payload_format_indicator, + message_expiry_interval, + topic_alias, + response_topic, + correlation_data, + user_properties, + subscription_identifiers, + content_type, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buffer.put_u8(PropertyType::PayloadFormatIndicator as u8); + buffer.put_u8(payload_format_indicator); + } + + if let Some(message_expiry_interval) = self.message_expiry_interval { + buffer.put_u8(PropertyType::MessageExpiryInterval as u8); + buffer.put_u32(message_expiry_interval); + } + + if let Some(topic_alias) = self.topic_alias { + buffer.put_u8(PropertyType::TopicAlias as u8); + buffer.put_u16(topic_alias); + } + + if let Some(topic) = &self.response_topic { + buffer.put_u8(PropertyType::ResponseTopic as u8); + write_mqtt_string(buffer, topic); + } + + if let Some(data) = &self.correlation_data { + buffer.put_u8(PropertyType::CorrelationData as u8); + write_mqtt_bytes(buffer, data); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + for id in self.subscription_identifiers.iter() { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + if let Some(typ) = &self.content_type { + buffer.put_u8(PropertyType::ContentType as u8); + write_mqtt_string(buffer, typ); + } + + Ok(()) + } +} + +impl fmt::Debug for Publish { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Topic = {}, Qos = {:?}, Retain = {}, Pkid = {:?}, Payload Size = {}", + self.topic, + self.qos, + self.retain, + self.pkid, + self.payload.len() + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::{Bytes, BytesMut}; + use pretty_assertions::assert_eq; + + fn sample_v5() -> Publish { + let publish_properties = PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: Some(4321), + topic_alias: Some(100), + response_topic: Some("topic".to_owned()), + correlation_data: Some(Bytes::from(vec![1, 2, 3, 4])), + user_properties: vec![("test".to_owned(), "test".to_owned())], + subscription_identifiers: vec![120, 121], + content_type: Some("test".to_owned()), + }; + + Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: false, + topic: "test".to_string(), + pkid: 42, + properties: Some(publish_properties), + payload: Bytes::from(vec![b't', b'e', b's', b't']), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x34, // payload type + 0x3e, // remaining len + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // topic name + 0x00, 0x2a, // pkid + 0x31, // properties len + 0x01, 0x01, // payload format indicator + 0x02, 0x00, 0x00, 0x10, 0xe1, // message_expiry_interval + 0x23, 0x00, 0x64, // topic alias + 0x08, 0x00, 0x05, 0x74, 0x6f, 0x70, 0x69, 0x63, // response topic + 0x09, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04, // correlation_data + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x0b, 0x78, // subscription_identifier + 0x0b, 0x79, // subscription_identifier + 0x03, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // content_type + 0x74, 0x65, 0x73, 0x74, // payload + ] + } + + #[test] + fn publish_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let publish_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let publish = Publish::read(fixed_header, publish_bytes).unwrap(); + assert_eq!(publish, sample_v5()); + } + + #[test] + fn publish_encoding_works() { + let publish = sample_v5(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } + + #[test] + fn missing_properties_are_encoded() {} +} diff --git a/rumqttc/src/v5/packet/pubrec.rs b/rumqttc/src/v5/packet/pubrec.rs new file mode 100644 index 000000000..5e8de572e --- /dev/null +++ b/rumqttc/src/v5/packet/pubrec.rs @@ -0,0 +1,252 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRecReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRec { + pub pkid: u16, + pub reason: PubRecReason, + pub properties: Option, +} + +impl PubRec { + pub fn new(pkid: u16) -> PubRec { + PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + // Unlike other packets, property length can be ignored if there are + // no properties in acks + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRec { + pkid, + reason: PubRecReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRec { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRec { + pkid, + reason: reason(ack_reason)?, + properties: PubRecProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x50); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRecReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRecProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRecProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRecProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRecReason::Success, + 16 => PubRecReason::NoMatchingSubscribers, + 128 => PubRecReason::UnspecifiedError, + 131 => PubRecReason::ImplementationSpecificError, + 135 => PubRecReason::NotAuthorized, + 144 => PubRecReason::TopicNameInvalid, + 145 => PubRecReason::PacketIdentifierInUse, + 151 => PubRecReason::QuotaExceeded, + 153 => PubRecReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRec { + let properties = PubRecProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRec { + pkid: 42, + reason: PubRecReason::NoMatchingSubscribers, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x50, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x10, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrec_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrec_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrec = PubRec::read(fixed_header, pubrec_bytes).unwrap(); + assert_eq!(pubrec, sample()); + } + + #[test] + fn pubrec_encoding_works() { + let pubrec = sample(); + let mut buf = BytesMut::new(); + pubrec.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/pubrel.rs b/rumqttc/src/v5/packet/pubrel.rs new file mode 100644 index 000000000..1a1a62e4d --- /dev/null +++ b/rumqttc/src/v5/packet/pubrel.rs @@ -0,0 +1,236 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum PubRelReason { + Success = 0, + PacketIdentifierNotFound = 146, +} + +/// Acknowledgement to QoS1 publish +#[derive(Debug, Clone, PartialEq)] +pub struct PubRel { + pub pkid: u16, + pub reason: PubRelReason, + pub properties: Option, +} + +impl PubRel { + pub fn new(pkid: u16) -> PubRel { + PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + } + } + + fn len(&self) -> usize { + let mut len = 2 + 1; // pkid + reason + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return 2; + } + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + if fixed_header.remaining_len == 2 { + return Ok(PubRel { + pkid, + reason: PubRelReason::Success, + properties: None, + }); + } + + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubRel { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubRel { + pkid, + reason: reason(ack_reason)?, + properties: PubRelProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0x62); + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(self.pkid); + + // If there are no properties during success, sending reason code is optional + if self.reason == PubRelReason::Success && self.properties.is_none() { + return Ok(4); + } + + buffer.put_u8(self.reason as u8); + + if let Some(properties) = &self.properties { + properties.write(buffer)?; + } + + Ok(1 + count + len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PubRelProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl PubRelProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubRelProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0 => PubRelReason::Success, + 146 => PubRelReason::PacketIdentifierNotFound, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> PubRel { + let properties = PubRelProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + PubRel { + pkid: 42, + reason: PubRelReason::PacketIdentifierNotFound, + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x62, // payload type + 0x18, // remaining length + 0x00, 0x2a, // packet id + 0x92, // reason + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason_string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + ] + } + + #[test] + fn pubrel_parsing_works() { + let mut stream = bytes::BytesMut::new(); + let packetstream = &sample_bytes(); + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let pubrel_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let pubrel = PubRel::read(fixed_header, pubrel_bytes).unwrap(); + assert_eq!(pubrel, sample()); + } + + #[test] + fn pubrel_encoding_works() { + let pubrel = sample(); + let mut buf = BytesMut::new(); + pubrel.write(&mut buf).unwrap(); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/suback.rs b/rumqttc/src/v5/packet/suback.rs new file mode 100644 index 000000000..0ec0ead05 --- /dev/null +++ b/rumqttc/src/v5/packet/suback.rs @@ -0,0 +1,263 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::convert::{TryFrom, TryInto}; + +/// Acknowledgement to subscribe +#[derive(Debug, Clone, PartialEq)] +pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, +} + +impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0x90); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, +} + +impl TryFrom for SubscribeReasonCode { + type Error = super::Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(super::Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> SubAck { + let properties = SubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + SubAck { + pkid: 42, + return_codes: vec![ + SubscribeReasonCode::QoS0, + SubscribeReasonCode::QoS1, + SubscribeReasonCode::QoS2, + SubscribeReasonCode::Unspecified, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x90, // packet type + 0x1b, // remaining len + 0x00, 0x2a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, + 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // user properties + 0x00, 0x01, 0x02, 0x80, // return codes + ] + } + + #[test] + fn suback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let suback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let suback = SubAck::read(fixed_header, suback_bytes).unwrap(); + assert_eq!(suback, sample()); + } + + #[test] + fn suback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/subscribe.rs b/rumqttc/src/v5/packet/subscribe.rs new file mode 100644 index 000000000..9a5c43d1e --- /dev/null +++ b/rumqttc/src/v5/packet/subscribe.rs @@ -0,0 +1,425 @@ +use super::*; +use bytes::{Buf, Bytes}; +use core::fmt; + +/// Subscription packet +#[derive(Clone, PartialEq)] +pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + Subscribe { + pkid: 0, + filters: vec![filter], + properties: None, + } + } + + pub fn new_many(topics: T) -> Subscribe + where + T: IntoIterator, + { + Subscribe { + pkid: 0, + filters: topics.into_iter().collect(), + properties: None, + } + } + + pub fn empty_subscribe() -> Subscribe { + Subscribe { + pkid: 0, + filters: Vec::new(), + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_string(&mut bytes)?; + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = nolocal != 0; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = preserve_retain != 0; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, +} + +impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Subscription filter +#[derive(Clone, PartialEq)] +pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, +} + +impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn set_nolocal(&mut self, flag: bool) -> &mut Self { + self.nolocal = flag; + self + } + + pub fn set_preserve_retain(&mut self, flag: bool) -> &mut Self { + self.preserve_retain = flag; + self + } + + pub fn set_retain_forward_rule(&mut self, rule: RetainForwardRule) -> &mut Self { + self.retain_forward_rule = rule; + self + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, +} + +impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } +} + +impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Subscribe { + let subscribe_properties = SubscribeProperties { + id: Some(100), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + let mut filter = SubscribeFilter::new("hello".to_owned(), QoS::AtLeastOnce); + filter + .set_nolocal(true) + .set_preserve_retain(true) + .set_retain_forward_rule(RetainForwardRule::Never); + + Subscribe { + pkid: 42, + filters: vec![filter], + properties: Some(subscribe_properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0x82, // packet type + 0x1a, // remaining length + 0x00, 0x2a, // pkid + 0x0f, // properties len + 0x0b, 0x64, // subscription identifier + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter + 0x2d, // options + ] + } + + #[test] + fn subscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Subscribe { + let filter = SubscribeFilter::new("hello/world".to_owned(), QoS::AtLeastOnce); + Subscribe { + pkid: 42, + filters: vec![filter], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0x82, 0x11, 0x00, 0x2a, 0x00, 0x00, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2f, 0x77, + 0x6f, 0x72, 0x6c, 0x64, 0x01, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Subscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsuback.rs b/rumqttc/src/v5/packet/unsuback.rs new file mode 100644 index 000000000..ce01bb4cd --- /dev/null +++ b/rumqttc/src/v5/packet/unsuback.rs @@ -0,0 +1,249 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +//// Return code in connack +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum UnsubAckReason { + Success = 0x00, + NoSubscriptionExisted = 0x11, + UnspecifiedError = 0x80, + ImplementationSpecificError = 0x83, + NotAuthorized = 0x87, + TopicFilterInvalid = 0x8F, + PacketIdentifierInUse = 0x91, +} + +/// Acknowledgement to unsubscribe +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAck { + pub pkid: u16, + pub reasons: Vec, + pub properties: Option, +} + +impl UnsubAck { + pub fn new(pkid: u16) -> UnsubAck { + UnsubAck { + pkid, + reasons: Vec::new(), + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.reasons.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = UnsubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut reasons = Vec::new(); + while bytes.has_remaining() { + let r = read_u8(&mut bytes)?; + reasons.push(reason(r)?); + } + + let unsuback = UnsubAck { + pkid, + reasons, + properties, + }; + + Ok(unsuback) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xB0); + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = self.reasons.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, +} + +impl UnsubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let reason = read_mqtt_string(&mut bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +/// Connection return code type +fn reason(num: u8) -> Result { + let code = match num { + 0x00 => UnsubAckReason::Success, + 0x11 => UnsubAckReason::NoSubscriptionExisted, + 0x80 => UnsubAckReason::UnspecifiedError, + 0x83 => UnsubAckReason::ImplementationSpecificError, + 0x87 => UnsubAckReason::NotAuthorized, + 0x8F => UnsubAckReason::TopicFilterInvalid, + 0x91 => UnsubAckReason::PacketIdentifierInUse, + num => return Err(Error::InvalidSubscribeReasonCode(num)), + }; + + Ok(code) +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> UnsubAck { + let properties = UnsubAckProperties { + reason_string: Some("test".to_owned()), + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + UnsubAck { + pkid: 10, + reasons: vec![ + UnsubAckReason::NotAuthorized, + UnsubAckReason::TopicFilterInvalid, + ], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xb0, // packet type + 0x19, // remaining len + 0x00, 0x0a, // pkid + 0x14, // properties len + 0x1f, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, // reason string + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x87, 0x8f, // reasons + ] + } + + #[test] + fn unsuback_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsuback_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsuback = UnsubAck::read(fixed_header, unsuback_bytes).unwrap(); + assert_eq!(unsuback, sample()); + } + + #[test] + fn unsuback_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } +} diff --git a/rumqttc/src/v5/packet/unsubscribe.rs b/rumqttc/src/v5/packet/unsubscribe.rs new file mode 100644 index 000000000..8700c5132 --- /dev/null +++ b/rumqttc/src/v5/packet/unsubscribe.rs @@ -0,0 +1,238 @@ +use super::*; +use bytes::{Buf, Bytes}; + +/// Unsubscribe packet +#[derive(Debug, Clone, PartialEq)] +pub struct Unsubscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, +} + +impl Unsubscribe { + pub fn new>(topic: S) -> Unsubscribe { + Unsubscribe { + pkid: 0, + filters: vec![topic.into()], + properties: None, + } + } + + pub fn len(&self) -> usize { + // Packet id + length of filters (unlike subscribe, this just a string. + // Hence 2 is prefixed for len per filter) + let mut len = 2 + self.filters.iter().fold(0, |s, t| 2 + s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + dbg!(pkid); + let properties = UnsubscribeProperties::extract(&mut bytes)?; + + let mut filters = Vec::with_capacity(1); + while bytes.has_remaining() { + let filter = read_mqtt_string(&mut bytes)?; + filters.push(filter); + } + + let unsubscribe = Unsubscribe { + pkid, + filters, + properties, + }; + Ok(unsubscribe) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + buffer.put_u8(0xA2); + + // write remaining length + let remaining_len = self.len(); + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(self.pkid); + + match &self.properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in self.filters.iter() { + write_mqtt_string(buffer, filter); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnsubscribeProperties { + pub user_properties: Vec<(String, String)>, +} + +impl UnsubscribeProperties { + fn len(&self) -> usize { + let mut len = 0; + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::UserProperty => { + let key = read_mqtt_string(&mut bytes)?; + let value = read_mqtt_string(&mut bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(UnsubscribeProperties { user_properties })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + fn sample() -> Unsubscribe { + let properties = UnsubscribeProperties { + user_properties: vec![("test".to_owned(), "test".to_owned())], + }; + + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned(), "world".to_owned()], + properties: Some(properties), + } + } + + fn sample_bytes() -> Vec { + vec![ + 0xa2, // packet type + 0x1e, // remaining len + 0x00, 0x0a, // pkid + 0x0d, // properties len + 0x26, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x04, 0x74, 0x65, 0x73, + 0x74, // user properties + 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // filter 1 + 0x00, 0x05, 0x77, 0x6f, 0x72, 0x6c, 0x64, // filter 2 + ] + } + + #[test] + fn unsubscribe_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let unsubscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let unsubscribe = Unsubscribe::read(fixed_header, unsubscribe_bytes).unwrap(); + assert_eq!(unsubscribe, sample()); + } + + #[test] + fn subscribe_encoding_works() { + let publish = sample(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + println!("{:X?}", buf); + println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample_bytes()); + } + + fn sample2() -> Unsubscribe { + Unsubscribe { + pkid: 10, + filters: vec!["hello".to_owned()], + properties: None, + } + } + + fn sample2_bytes() -> Vec { + vec![ + 0xa2, 0x0a, 0x00, 0x0a, 0x00, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + ] + } + + #[test] + fn subscribe2_parsing_works() { + let mut stream = BytesMut::new(); + let packetstream = &sample2_bytes(); + + stream.extend_from_slice(&packetstream[..]); + + let fixed_header = parse_fixed_header(stream.iter()).unwrap(); + let subscribe_bytes = stream.split_to(fixed_header.frame_length()).freeze(); + let subscribe = Unsubscribe::read(fixed_header, subscribe_bytes).unwrap(); + assert_eq!(subscribe, sample2()); + } + + #[test] + fn subscribe2_encoding_works() { + let publish = sample2(); + let mut buf = BytesMut::new(); + publish.write(&mut buf).unwrap(); + + // println!("{:X?}", buf); + // println!("{:#04X?}", &buf[..]); + assert_eq!(&buf[..], sample2_bytes()); + } +} diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs new file mode 100644 index 000000000..b6a8ebfd9 --- /dev/null +++ b/rumqttc/src/v5/state.rs @@ -0,0 +1,759 @@ +use std::{ + collections::VecDeque, + io, mem, + sync::{Arc, Mutex}, + time::Instant, +}; + +use bytes::BytesMut; + +use crate::v5::{outgoing_buf::OutgoingBuf, packet::*, Incoming, Request}; + +/// Errors during state handling +#[derive(Debug, thiserror::Error)] +pub enum StateError { + /// Io Error while state is passed to network + #[error("Io error {0:?}")] + Io(#[from] io::Error), + /// Broker's error reply to client's connect packet + #[error("Connect return code `{0:?}`")] + Connect(ConnectReturnCode), + /// Invalid state for a given operation + #[error("Invalid state for a given operation")] + InvalidState, + /// Received a packet (ack) which isn't asked for + #[error("Received unsolicited ack pkid {0}")] + Unsolicited(u16), + /// Last pingreq isn't acked + #[error("Last pingreq isn't acked")] + AwaitPingResp, + /// Received a wrong packet while waiting for another packet + #[error("Received a wrong packet while waiting for another packet")] + WrongPacket, + #[error("Timeout while waiting to resolve collision")] + CollisionTimeout, + #[error("Mqtt serialization/deserialization error")] + Deserialization(Error), +} + +impl From for StateError { + fn from(e: Error) -> StateError { + StateError::Deserialization(e) + } +} + +/// State of the mqtt connection. +// Design: Methods will just modify the state of the object without doing any network operations +// Design: All inflight queues are maintained in a pre initialized vec with index as packet id. +// This is done for 2 reasons +// Bad acks or out of order acks aren't O(n) causing cpu spikes +// Any missing acks from the broker are detected during the next recycled use of packet ids +#[derive(Debug, Clone)] +pub struct MqttState { + /// Status of last ping + pub await_pingresp: bool, + /// Collision ping count. Collisions stop user requests + /// which inturn trigger pings. Multiple pings without + /// resolving collisions will result in error + pub collision_ping_count: usize, + /// Last incoming packet time + last_incoming: Instant, + /// Last outgoing packet time + last_outgoing: Instant, + /// Number of outgoing inflight publishes + pub(crate) inflight: u16, + /// Outgoing QoS 1, 2 publishes which aren't acked yet + pub(crate) outgoing_pub: Vec>, + /// Packet ids of released QoS 2 publishes + pub(crate) outgoing_rel: Vec>, + /// Packet ids on incoming QoS 2 publishes + pub(crate) incoming_pub: Vec>, + /// Last collision due to broker not acking in order + pub collision: Option, + /// Write buffer + pub write: BytesMut, + /// Indicates if acknowledgements should be send immediately + pub manual_acks: bool, + pub(crate) incoming_buf: Arc>>, + pub(crate) outgoing_buf: Arc>, +} + +impl MqttState { + /// Creates new mqtt state. Same state should be used during a + /// connection for persistent sessions while new state should + /// instantiated for clean sessions + pub fn new(max_inflight: u16, manual_acks: bool, cap: usize) -> Self { + MqttState { + await_pingresp: false, + collision_ping_count: 0, + last_incoming: Instant::now(), + last_outgoing: Instant::now(), + inflight: 0, + // index 0 is wasted as 0 is not a valid packet id + outgoing_pub: vec![None; max_inflight as usize + 1], + outgoing_rel: vec![None; max_inflight as usize + 1], + incoming_pub: vec![None; std::u16::MAX as usize + 1], + collision: None, + // TODO: Optimize these sizes later + write: BytesMut::with_capacity(10 * 1024), + manual_acks, + incoming_buf: Arc::new(Mutex::new(VecDeque::with_capacity(cap))), + outgoing_buf: OutgoingBuf::new(max_inflight as usize), + } + } + + /// Returns inflight outgoing packets and clears internal queues + pub fn clean(&mut self) -> Vec { + let mut pending = Vec::with_capacity(100); + // remove and collect pending publishes + for publish in self.outgoing_pub.iter_mut() { + if let Some(publish) = publish.take() { + let request = Request::Publish(publish); + pending.push(request); + } + } + + // remove and collect pending releases + for rel in self.outgoing_rel.iter_mut() { + if let Some(pkid) = rel.take() { + let request = Request::PubRel(PubRel::new(pkid)); + pending.push(request); + } + } + + // remove packed ids of incoming qos2 publishes + for id in self.incoming_pub.iter_mut() { + id.take(); + } + + self.await_pingresp = false; + self.collision_ping_count = 0; + self.inflight = 0; + pending + } + + #[inline] + pub fn inflight(&self) -> u16 { + self.inflight + } + + #[inline] + pub fn cur_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().pkid_counter + } + + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should + /// be put on to the network by the eventloop + pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { + match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect => self.outgoing_disconnect()?, + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + _ => unimplemented!(), + }; + + self.last_outgoing = Instant::now(); + Ok(()) + } + + /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the + /// user to consume and `Packet` which for the eventloop to put on the network + /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will + /// be forwarded to user and Pubck packet will be written to network + pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { + let out = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp(), + Incoming::Publish(publish) => self.handle_incoming_publish(publish), + Incoming::SubAck(_suback) => self.handle_incoming_suback(), + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), + Incoming::PubAck(puback) => self.handle_incoming_puback(puback), + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + _ => { + error!("Invalid incoming packet = {:?}", packet); + return Err(StateError::WrongPacket); + } + }; + + out?; + self.incoming_buf.lock().unwrap().push_back(packet); + self.last_incoming = Instant::now(); + Ok(()) + } + + #[inline] + fn handle_incoming_suback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + #[inline] + fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { + Ok(()) + } + + /// Results in a publish notification in all the QoS cases. Replys with an ack + /// in case of QoS1 and Replys rec in case of QoS while also storing the message + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + let qos = publish.qos; + + match qos { + QoS::AtMostOnce => {} + QoS::AtLeastOnce => { + if !self.manual_acks { + let puback = PubAck::new(publish.pkid); + self.outgoing_puback(puback)? + } + } + QoS::ExactlyOnce => { + let pkid = publish.pkid; + self.incoming_pub[pkid as usize] = Some(pkid); + if !self.manual_acks { + let pubrec = PubRec::new(pkid); + self.outgoing_pubrec(pubrec)?; + } + } + } + + Ok(()) + } + + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + let v = match mem::replace(&mut self.outgoing_pub[puback.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited puback packet: {:?}", puback.pkid); + Err(StateError::Unsolicited(puback.pkid)) + } + }; + + if let Some(publish) = self.check_collision(puback.pkid) { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; + + publish.write(&mut self.write)?; + self.collision_ping_count = 0; + } + + v + } + + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + match mem::replace(&mut self.outgoing_pub[pubrec.pkid as usize], None) { + Some(_) => { + // NOTE: Inflight - 1 for qos2 in comp + self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); + PubRel::new(pubrec.pkid).write(&mut self.write)?; + Ok(()) + } + None => { + error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); + Err(StateError::Unsolicited(pubrec.pkid)) + } + } + } + + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + match mem::replace(&mut self.incoming_pub[pubrel.pkid as usize], None) { + Some(_) => { + PubComp::new(pubrel.pkid).write(&mut self.write)?; + Ok(()) + } + None => { + error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); + Err(StateError::Unsolicited(pubrel.pkid)) + } + } + } + + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + if let Some(publish) = self.check_collision(pubcomp.pkid) { + publish.write(&mut self.write)?; + self.collision_ping_count = 0; + } + + match mem::replace(&mut self.outgoing_rel[pubcomp.pkid as usize], None) { + Some(_) => { + self.inflight -= 1; + Ok(()) + } + None => { + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + Err(StateError::Unsolicited(pubcomp.pkid)) + } + } + } + + #[inline] + fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + self.await_pingresp = false; + Ok(()) + } + + /// Adds next packet identifier to QoS 1 and 2 publish packets and returns + /// it buy wrapping publish in packet + fn outgoing_publish(&mut self, publish: Publish) -> Result<(), StateError> { + if publish.qos != QoS::AtMostOnce { + // client should set proper pkid + let pkid = publish.pkid; + if self + .outgoing_pub + .get(publish.pkid as usize) + .unwrap() + .is_some() + { + info!("Collision on packet id = {:?}", publish.pkid); + self.collision = Some(publish); + return Ok(()); + } + + // if there is an existing publish at this pkid, this implies that broker hasn't acked this + // packet yet. This error is possible only when broker isn't acking sequentially + self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.inflight += 1; + }; + + debug!( + "Publish. Topic = {}, Pkid = {:?}, Payload Size = {:?}", + publish.topic, + publish.pkid, + publish.payload.len() + ); + + publish.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + let pubrel = self.save_pubrel(pubrel)?; + + debug!("Pubrel. Pkid = {}", pubrel.pkid); + PubRel::new(pubrel.pkid).write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + puback.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + pubrec.write(&mut self.write)?; + Ok(()) + } + + /// check when the last control packet/pingreq packet is received and return + /// the status which tells if keep alive time has exceeded + /// NOTE: status will be checked for zero keepalive times also + fn outgoing_ping(&mut self) -> Result<(), StateError> { + let elapsed_in = self.last_incoming.elapsed(); + let elapsed_out = self.last_outgoing.elapsed(); + + if self.collision.is_some() { + self.collision_ping_count += 1; + if self.collision_ping_count >= 2 { + return Err(StateError::CollisionTimeout); + } + } + + // raise error if last ping didn't receive ack + if self.await_pingresp { + return Err(StateError::AwaitPingResp); + } + + self.await_pingresp = true; + + debug!( + "Pingreq, + last incoming packet before {} millisecs, + last outgoing request before {} millisecs", + elapsed_in.as_millis(), + elapsed_out.as_millis() + ); + + PingReq.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_subscribe(&mut self, subscription: Subscribe) -> Result<(), StateError> { + // client should set correct pkid + debug!( + "Subscribe. Topics = {:?}, Pkid = {:?}", + subscription.filters, subscription.pkid + ); + + subscription.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), StateError> { + debug!( + "Unsubscribe. Topics = {:?}, Pkid = {:?}", + unsub.filters, unsub.pkid + ); + + unsub.write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + debug!("Disconnect"); + + Disconnect::new().write(&mut self.write)?; + Ok(()) + } + + #[inline] + fn check_collision(&mut self, pkid: u16) -> Option { + if let Some(publish) = &self.collision { + if publish.pkid == pkid { + return self.collision.take(); + } + } + + None + } + + #[inline] + fn save_pubrel(&mut self, pubrel: PubRel) -> Result { + // pubrel's pkid should already be set correct + self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid); + Ok(pubrel) + } + + #[inline] + pub fn increment_pkid(&self) -> u16 { + self.outgoing_buf.lock().unwrap().increment_pkid() + } + + ///// http://stackoverflow.com/questions/11115364/mqtt-messageid-practical-implementation + ///// Packet ids are incremented till maximum set inflight messages and reset to 1 after that. + ///// + //fn next_pkid(&mut self) -> u16 { + // let next_pkid = self.last_pkid + 1; + + // // When next packet id is at the edge of inflight queue, + // // set await flag. This instructs eventloop to stop + // // processing requests until all the inflight publishes + // // are acked + // if next_pkid == self.max_inflight { + // self.last_pkid = 0; + // return next_pkid; + // } + + // self.last_pkid = next_pkid; + // next_pkid + //} +} + +#[cfg(test)] +mod test { + use super::{MqttState, StateError}; + use crate::v5::{packet::*, Incoming, MqttOptions, Request}; + + fn build_outgoing_publish(qos: QoS) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.qos = qos; + publish + } + + fn build_incoming_publish(qos: QoS, pkid: u16) -> Publish { + let topic = "hello/world".to_owned(); + let payload = vec![1, 2, 3]; + + let mut publish = Publish::new(topic, QoS::AtLeastOnce, payload); + publish.pkid = pkid; + publish.qos = qos; + publish + } + + fn build_mqttstate() -> MqttState { + MqttState::new(100, false, 100) + } + + #[test] + fn next_pkid_increments_as_expected() { + let mqtt = build_mqttstate(); + + for i in 1..=100 { + let pkid = mqtt.increment_pkid(); + + // loops between 0-99. % 100 == 0 implies border + let expected = i % 100; + if expected == 0 { + break; + } + + assert_eq!(expected, pkid); + } + } + + #[test] + fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { + let mut mqtt = build_mqttstate(); + + // QoS0 Publish + let mut publish = build_outgoing_publish(QoS::AtMostOnce); + publish.pkid = 1; + + // QoS 0 publish shouldn't be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 0); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 2; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 1); + + // Packet id should be incremented and publish should be saved in queue + publish.pkid = 3; + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 2); + + // QoS1 Publish + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 4; + + // Packet id should be set and publish should be saved in queue + mqtt.outgoing_publish(publish.clone()).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 3); + + publish.pkid = 5; + // Packet id should be incremented and publish should be saved in queue + mqtt.outgoing_publish(publish).unwrap(); + // cur_pkid == 0 as there is no client to update it + assert_eq!(mqtt.cur_pkid(), 0); + assert_eq!(mqtt.inflight, 4); + } + + #[test] + fn incoming_publish_should_be_added_to_queue_correctly() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + + // only qos2 publish should be add to queue + assert_eq!(pkid, 3); + } + + #[test] + fn incoming_publish_should_be_acked() { + let mut mqtt = build_mqttstate(); + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + } + + #[test] + fn incoming_publish_should_not_be_acked_with_manual_acks() { + let mut mqtt = build_mqttstate(); + mqtt.manual_acks = true; + + // QoS0, 1, 2 Publishes + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + + mqtt.handle_incoming_publish(&publish1).unwrap(); + mqtt.handle_incoming_publish(&publish2).unwrap(); + mqtt.handle_incoming_publish(&publish3).unwrap(); + + let pkid = mqtt.incoming_pub[3].unwrap(); + assert_eq!(pkid, 3); + + assert!(mqtt.incoming_buf.lock().unwrap().is_empty()); + } + + #[test] + fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + _ => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_puback_should_remove_correct_publish_from_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + mqtt.outgoing_publish(publish1).unwrap(); + mqtt.outgoing_publish(publish2).unwrap(); + assert_eq!(mqtt.inflight, 2); + + mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 1); + + mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 0); + + assert!(mqtt.outgoing_pub[1].is_none()); + assert!(mqtt.outgoing_pub[2].is_none()); + } + + #[test] + fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let mut mqtt = build_mqttstate(); + + let mut publish1 = build_outgoing_publish(QoS::AtLeastOnce); + let mut publish2 = build_outgoing_publish(QoS::ExactlyOnce); + publish1.pkid = 1; + publish2.pkid = 2; + + let _publish_out = mqtt.outgoing_publish(publish1); + let _publish_out = mqtt.outgoing_publish(publish2); + + mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + assert_eq!(mqtt.inflight, 2); + + // check if the remaining element's pkid is 1 + let backup = mqtt.outgoing_pub[1].clone(); + assert_eq!(backup.unwrap().pkid, 1); + + // check if the qos2 element's release pkid is 2 + assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2); + } + + #[test] + fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + mqtt.outgoing_publish(publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::Publish(publish) => assert_eq!(publish.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { + let mut mqtt = build_mqttstate(); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); + + mqtt.handle_incoming_publish(&publish).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + + mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); + let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + match packet { + Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), + packet => panic!("Invalid network request: {:?}", packet), + } + } + + #[test] + fn incoming_pubcomp_should_release_correct_pkid_from_release_queue() { + let mut mqtt = build_mqttstate(); + let mut publish = build_outgoing_publish(QoS::ExactlyOnce); + publish.pkid = 1; + + mqtt.outgoing_publish(publish).unwrap(); + mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + + mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + assert_eq!(mqtt.inflight, 0); + } + + #[test] + fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() { + let mut mqtt = build_mqttstate(); + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + mqtt.outgoing_ping().unwrap(); + + // network activity other than pingresp + let mut publish = build_outgoing_publish(QoS::AtLeastOnce); + publish.pkid = 1; + mqtt.handle_outgoing_packet(Request::Publish(publish)) + .unwrap(); + mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) + .unwrap(); + + // should throw error because we didn't get pingresp for previous ping + match mqtt.outgoing_ping() { + Ok(_) => panic!("Should throw pingresp await error"), + Err(StateError::AwaitPingResp) => (), + Err(e) => panic!("Should throw pingresp await error. Error = {:?}", e), + } + } + + #[test] + fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() { + let mut mqtt = build_mqttstate(); + + let mut opts = MqttOptions::new("test", "localhost", 1883); + opts.set_keep_alive(std::time::Duration::from_secs(10)); + + // should ping + mqtt.outgoing_ping().unwrap(); + mqtt.handle_incoming_packet(Incoming::PingResp).unwrap(); + + // should ping + mqtt.outgoing_ping().unwrap(); + } +} diff --git a/rumqttc/src/v5/tls.rs b/rumqttc/src/v5/tls.rs new file mode 100644 index 000000000..3936b2ca8 --- /dev/null +++ b/rumqttc/src/v5/tls.rs @@ -0,0 +1,130 @@ +use tokio::net::TcpStream; +use tokio_rustls::rustls; +use tokio_rustls::rustls::client::InvalidDnsNameError; +use tokio_rustls::rustls::{ + Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName, +}; +use tokio_rustls::webpki; +use tokio_rustls::{client::TlsStream, TlsConnector}; + +use crate::v5::{Key, MqttOptions, TlsConfiguration}; + +use std::convert::TryFrom; +use std::io; +use std::io::{BufReader, Cursor}; +use std::net::AddrParseError; +use std::sync::Arc; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Addr")] + Addr(#[from] AddrParseError), + #[error("I/O")] + Io(#[from] io::Error), + #[error("Web Pki")] + WebPki(#[from] webpki::Error), + #[error("DNS name")] + DNSName(#[from] InvalidDnsNameError), + #[error("TLS error")] + TLS(#[from] rustls::Error), + #[error("No valid cert in chain")] + NoValidCertInChain, +} + +// The cert handling functions return unit right now, this is a shortcut +impl From<()> for Error { + fn from(_: ()) -> Self { + Error::NoValidCertInChain + } +} + +pub async fn tls_connector(tls_config: &TlsConfiguration) -> Result { + let config = match tls_config { + TlsConfiguration::Simple { + ca, + alpn, + client_auth, + } => { + // Add ca to root store if the connection is TLS + let mut root_cert_store = RootCertStore::empty(); + let certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))?; + + let trust_anchors = certs.iter().map_while(|cert| { + if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { + Some(OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + )) + } else { + None + } + }); + + root_cert_store.add_server_trust_anchors(trust_anchors); + + if root_cert_store.is_empty() { + return Err(Error::NoValidCertInChain); + } + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + // Add der encoded client cert and key + let mut config = if let Some(client) = client_auth.as_ref() { + let certs = + rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client.0.clone())))?; + // load appropriate Key as per the user request. The underlying signature algorithm + // of key generation determines the Signature Algorithm during the TLS Handskahe. + let read_keys = match &client.1 { + Key::RSA(k) => rustls_pemfile::rsa_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + Key::ECC(k) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new( + Cursor::new(k.clone()), + )), + }; + let keys = match read_keys { + Ok(v) => v, + Err(_e) => return Err(Error::NoValidCertInChain), + }; + + // Get the first key. Error if it's not valid + let key = match keys.first() { + Some(k) => k.clone(), + None => return Err(Error::NoValidCertInChain), + }; + + let certs = certs.into_iter().map(Certificate).collect(); + + config.with_single_cert(certs, PrivateKey(key))? + } else { + config.with_no_client_auth() + }; + + // Set ALPN + if let Some(alpn) = alpn.as_ref() { + config.alpn_protocols.extend_from_slice(alpn); + } + + Arc::new(config) + } + TlsConfiguration::Rustls(tls_client_config) => tls_client_config.clone(), + }; + + Ok(TlsConnector::from(config)) +} + +pub async fn tls_connect( + options: &MqttOptions, + tls_config: &TlsConfiguration, +) -> Result, Error> { + let addr = options.broker_addr.as_str(); + let port = options.port; + let connector = tls_connector(tls_config).await?; + let domain = ServerName::try_from(addr)?; + let tcp = TcpStream::connect((addr, port)).await?; + let tls = connector.connect(domain, tcp).await?; + Ok(tls) +}