diff --git a/Cargo.toml b/Cargo.toml index b64c2f68756..7249e70cb4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ default = [ "identify", "kad", "gossipsub", - "mdns", + "mdns-async-io", "mplex", "noise", "ping", @@ -46,7 +46,8 @@ identify = ["dep:libp2p-identify", "libp2p-metrics?/identify"] kad = ["dep:libp2p-kad", "libp2p-metrics?/kad"] gossipsub = ["dep:libp2p-gossipsub", "libp2p-metrics?/gossipsub"] metrics = ["dep:libp2p-metrics"] -mdns = ["dep:libp2p-mdns"] +mdns-async-io = ["dep:libp2p-mdns", "libp2p-mdns?/async-io"] +mdns-tokio = ["dep:libp2p-mdns", "libp2p-mdns?/tokio"] mplex = ["dep:libp2p-mplex"] noise = ["dep:libp2p-noise"] ping = ["dep:libp2p-ping", "libp2p-metrics?/ping"] @@ -106,7 +107,7 @@ smallvec = "1.6.1" [target.'cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))'.dependencies] libp2p-deflate = { version = "0.35.0", path = "transports/deflate", optional = true } libp2p-dns = { version = "0.35.0", path = "transports/dns", optional = true, default-features = false } -libp2p-mdns = { version = "0.40.0", path = "protocols/mdns", optional = true } +libp2p-mdns = { version = "0.40.0", path = "protocols/mdns", optional = true, default-features = false } libp2p-tcp = { version = "0.35.0", path = "transports/tcp", default-features = false, optional = true } libp2p-websocket = { version = "0.37.0", path = "transports/websocket", optional = true } @@ -160,7 +161,7 @@ required-features = ["floodsub"] [[example]] name = "chat-tokio" -required-features = ["tcp-tokio", "mdns"] +required-features = ["tcp-tokio", "mdns-tokio"] [[example]] name = "file-sharing" diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index 0bd44bdabdc..f82d30934c9 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -25,7 +25,7 @@ //! The example is run per node as follows: //! //! ```sh -//! cargo run --example chat-tokio --features="tcp-tokio mdns" +//! cargo run --example chat-tokio --features="tcp-tokio mdns-tokio" //! ``` //! //! Alternatively, to run with the minimal set of features and crates: @@ -33,7 +33,7 @@ //! ```sh //!cargo run --example chat-tokio \\ //! --no-default-features \\ -//! --features="floodsub mplex noise tcp-tokio mdns" +//! --features="floodsub mplex noise tcp-tokio mdns-tokio" //! ``` use futures::StreamExt; @@ -41,7 +41,11 @@ use libp2p::{ core::upgrade, floodsub::{self, Floodsub, FloodsubEvent}, identity, - mdns::{Mdns, MdnsEvent}, + mdns::{ + MdnsEvent, + // `TokioMdns` is available through the `mdns-tokio` feature. + TokioMdns, + }, mplex, noise, swarm::{SwarmBuilder, SwarmEvent}, @@ -88,7 +92,7 @@ async fn main() -> Result<(), Box> { #[behaviour(out_event = "MyBehaviourEvent")] struct MyBehaviour { floodsub: Floodsub, - mdns: Mdns, + mdns: TokioMdns, } #[allow(clippy::large_enum_variant)] @@ -111,7 +115,7 @@ async fn main() -> Result<(), Box> { // Create a Swarm to manage peers and events. let mut swarm = { - let mdns = Mdns::new(Default::default()).await?; + let mdns = TokioMdns::new(Default::default()).await?; let mut behaviour = MyBehaviour { floodsub: Floodsub::new(peer_id), mdns, diff --git a/protocols/mdns/CHANGELOG.md b/protocols/mdns/CHANGELOG.md index a6b05044fa5..4a22bdb1c3f 100644 --- a/protocols/mdns/CHANGELOG.md +++ b/protocols/mdns/CHANGELOG.md @@ -2,6 +2,14 @@ - Update to `libp2p-swarm` `v0.39.0`. +- Allow users to choose between async-io and tokio runtime + in the mdns protocol implementation. `async-io` is a default + feature, with an additional `tokio` feature (see [PR 2748]) + +- Fix high CPU usage with Tokio library (see [PR 2748]). + +[PR 2748]: https://github.com/libp2p/rust-libp2p/pull/2748 + # 0.39.0 - Update to `libp2p-swarm` `v0.38.0`. diff --git a/protocols/mdns/Cargo.toml b/protocols/mdns/Cargo.toml index 306f052cc69..10883a82ec0 100644 --- a/protocols/mdns/Cargo.toml +++ b/protocols/mdns/Cargo.toml @@ -11,7 +11,6 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] -async-io = "1.3.1" data-encoding = "2.3.2" dns-parser = "0.8.0" futures = "0.3.13" @@ -25,8 +24,26 @@ smallvec = "1.6.1" socket2 = { version = "0.4.0", features = ["all"] } void = "1.0.2" +async-io = { version = "1.3.1", optional = true } +tokio = { version = "1.19", default-features = false, features = ["net", "time"], optional = true} + +[features] +default = ["async-io"] +tokio = ["dep:tokio"] +async-io = ["dep:async-io"] + [dev-dependencies] async-std = { version = "1.9.0", features = ["attributes"] } env_logger = "0.9.0" -libp2p = { path = "../..", default-features = false, features = ["mdns", "tcp-async-io", "dns-async-std", "websocket", "noise", "mplex", "yamux"] } -tokio = { version = "1.15", default-features = false, features = ["macros", "rt", "rt-multi-thread", "time"] } +libp2p = { path = "../..", default-features = false, features = ["mdns-async-io", "tcp-async-io", "dns-async-std", "tcp-tokio", "dns-tokio", "websocket", "noise", "mplex", "yamux"] } +tokio = { version = "1.19", default-features = false, features = ["macros", "rt", "rt-multi-thread", "time"] } + + +[[test]] +name = "use-async-std" +required-features = ["async-io"] + +[[test]] +name = "use-tokio" +required-features = ["tokio"] + diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index 244b2b784dd..854bd885a22 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -19,11 +19,14 @@ // DEALINGS IN THE SOFTWARE. mod iface; +mod socket; +mod timer; use self::iface::InterfaceState; +use crate::behaviour::{socket::AsyncSocket, timer::Builder}; use crate::MdnsConfig; -use async_io::Timer; use futures::prelude::*; +use futures::Stream; use if_watch::{IfEvent, IfWatcher}; use libp2p_core::transport::ListenerId; use libp2p_core::{Multiaddr, PeerId}; @@ -35,10 +38,24 @@ use smallvec::SmallVec; use std::collections::hash_map::{Entry, HashMap}; use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant}; +#[cfg(feature = "async-io")] +use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer}; + +/// The type of a [`GenMdns`] using the `async-io` implementation. +#[cfg(feature = "async-io")] +pub type Mdns = GenMdns; + +#[cfg(feature = "tokio")] +use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer}; + +/// The type of a [`GenMdns`] using the `tokio` implementation. +#[cfg(feature = "tokio")] +pub type TokioMdns = GenMdns; + /// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds /// them to the topology. #[derive(Debug)] -pub struct Mdns { +pub struct GenMdns { /// InterfaceState config. config: MdnsConfig, @@ -46,7 +63,7 @@ pub struct Mdns { if_watch: IfWatcher, /// Mdns interface states. - iface_states: HashMap, + iface_states: HashMap>, /// List of nodes that we have discovered, the address, and when their TTL expires. /// @@ -57,10 +74,13 @@ pub struct Mdns { /// Future that fires when the TTL of at least one node in `discovered_nodes` expires. /// /// `None` if `discovered_nodes` is empty. - closest_expiration: Option, + closest_expiration: Option, } -impl Mdns { +impl GenMdns +where + T: Builder, +{ /// Builds a new `Mdns` behaviour. pub async fn new(config: MdnsConfig) -> io::Result { let if_watch = if_watch::IfWatcher::new().await?; @@ -91,11 +111,15 @@ impl Mdns { *expires = now; } } - self.closest_expiration = Some(Timer::at(now)); + self.closest_expiration = Some(T::at(now)); } } -impl NetworkBehaviour for Mdns { +impl NetworkBehaviour for GenMdns +where + T: Builder + Stream, + S: AsyncSocket, +{ type ConnectionHandler = DummyConnectionHandler; type OutEvent = MdnsEvent; @@ -219,8 +243,9 @@ impl NetworkBehaviour for Mdns { return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } if let Some(closest_expiration) = closest_expiration { - let mut timer = Timer::at(closest_expiration); - let _ = Pin::new(&mut timer).poll(cx); + let mut timer = T::at(closest_expiration); + let _ = Pin::new(&mut timer).poll_next(cx); + self.closest_expiration = Some(timer); } Poll::Pending diff --git a/protocols/mdns/src/behaviour/iface.rs b/protocols/mdns/src/behaviour/iface.rs index e4971e36b1a..c5bacced138 100644 --- a/protocols/mdns/src/behaviour/iface.rs +++ b/protocols/mdns/src/behaviour/iface.rs @@ -23,9 +23,8 @@ mod query; use self::dns::{build_query, build_query_response, build_service_discovery_response}; use self::query::MdnsPacket; +use crate::behaviour::{socket::AsyncSocket, timer::Builder}; use crate::MdnsConfig; -use async_io::{Async, Timer}; -use futures::prelude::*; use libp2p_core::{address_translation, multiaddr::Protocol, Multiaddr, PeerId}; use libp2p_swarm::PollParameters; use socket2::{Domain, Socket, Type}; @@ -34,20 +33,20 @@ use std::{ io, iter, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, pin::Pin, - task::Context, + task::{Context, Poll}, time::{Duration, Instant}, }; /// An mDNS instance for a networking interface. To discover all peers when having multiple /// interfaces an [`InterfaceState`] is required for each interface. #[derive(Debug)] -pub struct InterfaceState { +pub struct InterfaceState { /// Address this instance is bound to. addr: IpAddr, /// Receive socket. - recv_socket: Async, + recv_socket: U, /// Send socket. - send_socket: Async, + send_socket: U, /// Buffer used for receiving data from the main socket. /// RFC6762 discourages packets larger than the interface MTU, but allows sizes of up to 9000 /// bytes, if it can be ensured that all participating devices can handle such large packets. @@ -60,7 +59,7 @@ pub struct InterfaceState { /// Discovery interval. query_interval: Duration, /// Discovery timer. - timeout: Timer, + timeout: T, /// Multicast address. multicast_addr: IpAddr, /// Discovered addresses. @@ -69,7 +68,11 @@ pub struct InterfaceState { ttl: Duration, } -impl InterfaceState { +impl InterfaceState +where + U: AsyncSocket, + T: Builder + futures::Stream, +{ /// Builds a new [`InterfaceState`]. pub fn new(addr: IpAddr, config: MdnsConfig) -> io::Result { log::info!("creating instance on iface {}", addr); @@ -83,7 +86,7 @@ impl InterfaceState { socket.set_multicast_loop_v4(true)?; socket.set_multicast_ttl_v4(255)?; socket.join_multicast_v4(&*crate::IPV4_MDNS_MULTICAST_ADDRESS, &addr)?; - Async::new(UdpSocket::from(socket))? + U::from_std(UdpSocket::from(socket))? } IpAddr::V6(_) => { let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?; @@ -94,7 +97,7 @@ impl InterfaceState { socket.set_multicast_loop_v6(true)?; // TODO: find interface matching addr. socket.join_multicast_v6(&*crate::IPV6_MDNS_MULTICAST_ADDRESS, 0)?; - Async::new(UdpSocket::from(socket))? + U::from_std(UdpSocket::from(socket))? } }; let bind_addr = match addr { @@ -107,7 +110,8 @@ impl InterfaceState { SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) } }; - let send_socket = Async::new(UdpSocket::bind(bind_addr)?)?; + let send_socket = U::from_std(UdpSocket::bind(bind_addr)?)?; + // randomize timer to prevent all converging and firing at the same time. let query_interval = { use rand::Rng; @@ -127,19 +131,18 @@ impl InterfaceState { send_buffer: Default::default(), discovered: Default::default(), query_interval, - timeout: Timer::interval_at(Instant::now(), query_interval), + timeout: T::interval_at(Instant::now(), query_interval), multicast_addr, ttl: config.ttl, }) } pub fn reset_timer(&mut self) { - self.timeout.set_interval(self.query_interval); + self.timeout = T::interval(self.query_interval); } pub fn fire_timer(&mut self) { - self.timeout - .set_interval_at(Instant::now(), self.query_interval); + self.timeout = T::interval_at(Instant::now(), self.query_interval); } fn inject_mdns_packet(&mut self, packet: MdnsPacket, params: &impl PollParameters) { @@ -171,17 +174,17 @@ impl InterfaceState { let new_expiration = Instant::now() + peer.ttl(); - let mut addrs: Vec = Vec::new(); for addr in peer.addresses() { if let Some(new_addr) = address_translation(addr, &observed) { - addrs.push(new_addr.clone()) + self.discovered.push_back(( + *peer.id(), + new_addr.clone(), + new_expiration, + )); } - addrs.push(addr.clone()) - } - for addr in addrs { self.discovered - .push_back((*peer.id(), addr, new_expiration)); + .push_back((*peer.id(), addr.clone(), new_expiration)); } } } @@ -198,43 +201,49 @@ impl InterfaceState { params: &impl PollParameters, ) -> Option<(PeerId, Multiaddr, Instant)> { // Poll receive socket. - while self.recv_socket.poll_readable(cx).is_ready() { - match self - .recv_socket - .recv_from(&mut self.recv_buffer) - .now_or_never() - { - Some(Ok((len, from))) => { + while let Poll::Ready(data) = + Pin::new(&mut self.recv_socket).poll_read(cx, &mut self.recv_buffer) + { + match data { + Ok((len, from)) => { if let Some(packet) = MdnsPacket::new_from_bytes(&self.recv_buffer[..len], from) { self.inject_mdns_packet(packet, params); } } - Some(Err(err)) => log::error!("Failed reading datagram: {}", err), - None => {} + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + // No more bytes available on the socket to read + break; + } + Err(err) => { + log::error!("failed reading datagram: {}", err); + } } } + // Send responses. - while self.send_socket.poll_writable(cx).is_ready() { - if let Some(packet) = self.send_buffer.pop_front() { - match self - .send_socket - .send_to(&packet, SocketAddr::new(self.multicast_addr, 5353)) - .now_or_never() - { - Some(Ok(_)) => log::trace!("sent packet on iface {}", self.addr), - Some(Err(err)) => { - log::error!("error sending packet on iface {}: {}", self.addr, err) - } - None => self.send_buffer.push_front(packet), + while let Some(packet) = self.send_buffer.pop_front() { + match Pin::new(&mut self.send_socket).poll_write( + cx, + &packet, + SocketAddr::new(self.multicast_addr, 5353), + ) { + Poll::Ready(Ok(_)) => log::trace!("sent packet on iface {}", self.addr), + Poll::Ready(Err(err)) => { + log::error!("error sending packet on iface {} {}", self.addr, err); + } + Poll::Pending => { + self.send_buffer.push_front(packet); + break; } - } else if Pin::new(&mut self.timeout).poll_next(cx).is_ready() { - log::trace!("sending query on iface {}", self.addr); - self.send_buffer.push_back(build_query()); - } else { - break; } } + + if Pin::new(&mut self.timeout).poll_next(cx).is_ready() { + log::trace!("sending query on iface {}", self.addr); + self.send_buffer.push_back(build_query()); + } + // Emit discovered event. self.discovered.pop_front() } diff --git a/protocols/mdns/src/behaviour/socket.rs b/protocols/mdns/src/behaviour/socket.rs new file mode 100644 index 00000000000..4406ed33fde --- /dev/null +++ b/protocols/mdns/src/behaviour/socket.rs @@ -0,0 +1,134 @@ +// Copyright 2018 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::{ + io::Error, + marker::Unpin, + net::{SocketAddr, UdpSocket}, + task::{Context, Poll}, +}; + +/// Interface that must be implemented by the different runtimes to use the [`UdpSocket`] in async mode +pub trait AsyncSocket: Unpin + Send + 'static { + /// Create the async socket from the [`std::net::UdpSocket`] + fn from_std(socket: UdpSocket) -> std::io::Result + where + Self: Sized; + + /// Attempts to receive a single packet on the socket from the remote address to which it is connected. + fn poll_read( + &mut self, + _cx: &mut Context, + _buf: &mut [u8], + ) -> Poll>; + + /// Attempts to send data on the socket to a given address. + fn poll_write( + &mut self, + _cx: &mut Context, + _packet: &[u8], + _to: SocketAddr, + ) -> Poll>; +} + +#[cfg(feature = "async-io")] +pub mod asio { + use super::*; + use async_io::Async; + use futures::FutureExt; + + /// AsyncIo UdpSocket + pub type AsyncUdpSocket = Async; + + impl AsyncSocket for AsyncUdpSocket { + fn from_std(socket: UdpSocket) -> std::io::Result { + Async::new(socket) + } + + fn poll_read( + &mut self, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + // Poll receive socket. + futures::ready!(self.poll_readable(cx))?; + match self.recv_from(buf).now_or_never() { + Some(data) => Poll::Ready(data), + None => Poll::Pending, + } + } + + fn poll_write( + &mut self, + cx: &mut Context, + packet: &[u8], + to: SocketAddr, + ) -> Poll> { + futures::ready!(self.poll_writable(cx))?; + match self.send_to(packet, to).now_or_never() { + Some(Ok(_)) => Poll::Ready(Ok(())), + Some(Err(err)) => Poll::Ready(Err(err)), + None => Poll::Pending, + } + } + } +} + +#[cfg(feature = "tokio")] +pub mod tokio { + use super::*; + use ::tokio::{io::ReadBuf, net::UdpSocket as TkUdpSocket}; + + /// Tokio ASync Socket` + pub type TokioUdpSocket = TkUdpSocket; + + impl AsyncSocket for TokioUdpSocket { + fn from_std(socket: UdpSocket) -> std::io::Result { + socket.set_nonblocking(true)?; + TokioUdpSocket::from_std(socket) + } + + fn poll_read( + &mut self, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + let mut rbuf = ReadBuf::new(buf); + match self.poll_recv_from(cx, &mut rbuf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Ready(Ok(addr)) => Poll::Ready(Ok((rbuf.filled().len(), addr))), + } + } + + fn poll_write( + &mut self, + cx: &mut Context, + packet: &[u8], + to: SocketAddr, + ) -> Poll> { + match self.poll_send_to(cx, packet, to) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Ready(Ok(_len)) => Poll::Ready(Ok(())), + } + } + } +} diff --git a/protocols/mdns/src/behaviour/timer.rs b/protocols/mdns/src/behaviour/timer.rs new file mode 100644 index 00000000000..fbdeb065b70 --- /dev/null +++ b/protocols/mdns/src/behaviour/timer.rs @@ -0,0 +1,128 @@ +// Copyright 2018 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::{ + marker::Unpin, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +/// Simple wrapper for the differents type of timers +#[derive(Debug)] +pub struct Timer { + inner: T, +} + +/// Builder interface to homogenize the differents implementations +pub trait Builder: Send + Unpin + 'static { + /// Creates a timer that emits an event once at the given time instant. + fn at(instant: Instant) -> Self; + + /// Creates a timer that emits events periodically. + fn interval(duration: Duration) -> Self; + + /// Creates a timer that emits events periodically, starting at start. + fn interval_at(start: Instant, duration: Duration) -> Self; +} + +#[cfg(feature = "async-io")] +pub mod asio { + use super::*; + use async_io::Timer as AsioTimer; + use futures::Stream; + + /// Async Timer + pub type AsyncTimer = Timer; + + impl Builder for AsyncTimer { + fn at(instant: Instant) -> Self { + Self { + inner: AsioTimer::at(instant), + } + } + + fn interval(duration: Duration) -> Self { + Self { + inner: AsioTimer::interval(duration), + } + } + + fn interval_at(start: Instant, duration: Duration) -> Self { + Self { + inner: AsioTimer::interval_at(start, duration), + } + } + } + + impl Stream for AsyncTimer { + type Item = Instant; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } + } +} + +#[cfg(feature = "tokio")] +pub mod tokio { + use super::*; + use ::tokio::time::{self, Instant as TokioInstant, Interval, MissedTickBehavior}; + use futures::Stream; + + /// Tokio wrapper + pub type TokioTimer = Timer; + + impl Builder for TokioTimer { + fn at(instant: Instant) -> Self { + // Taken from: https://docs.rs/async-io/1.7.0/src/async_io/lib.rs.html#91 + let mut inner = time::interval_at( + TokioInstant::from_std(instant), + Duration::new(std::u64::MAX, 1_000_000_000 - 1), + ); + inner.set_missed_tick_behavior(MissedTickBehavior::Skip); + Self { inner } + } + + fn interval(duration: Duration) -> Self { + let mut inner = time::interval_at(TokioInstant::now() + duration, duration); + inner.set_missed_tick_behavior(MissedTickBehavior::Skip); + Self { inner } + } + + fn interval_at(start: Instant, duration: Duration) -> Self { + let mut inner = time::interval_at(TokioInstant::from_std(start), duration); + inner.set_missed_tick_behavior(MissedTickBehavior::Skip); + Self { inner } + } + } + + impl Stream for TokioTimer { + type Item = TokioInstant; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_tick(cx).map(Some) + } + + fn size_hint(&self) -> (usize, Option) { + (std::usize::MAX, None) + } + } +} diff --git a/protocols/mdns/src/lib.rs b/protocols/mdns/src/lib.rs index a99eab691a2..3b484c91daa 100644 --- a/protocols/mdns/src/lib.rs +++ b/protocols/mdns/src/lib.rs @@ -26,16 +26,22 @@ //! //! # Usage //! -//! This crate provides the `Mdns` struct which implements the `NetworkBehaviour` trait. This -//! struct will automatically discover other libp2p nodes on the local network. +//! This crate provides a `Mdns` and `TokioMdns`, depending on the enabled features, which +//! implements the `NetworkBehaviour` trait. This struct will automatically discover other +//! libp2p nodes on the local network. //! use lazy_static::lazy_static; use std::net::{Ipv4Addr, Ipv6Addr}; use std::time::Duration; mod behaviour; +pub use crate::behaviour::{GenMdns, MdnsEvent}; -pub use crate::behaviour::{Mdns, MdnsEvent}; +#[cfg(feature = "async-io")] +pub use crate::behaviour::Mdns; + +#[cfg(feature = "tokio")] +pub use crate::behaviour::TokioMdns; /// The DNS service name for all libp2p peers used to query for addresses. const SERVICE_NAME: &[u8] = b"_p2p._udp.local"; diff --git a/protocols/mdns/tests/smoke.rs b/protocols/mdns/tests/use-async-std.rs similarity index 87% rename from protocols/mdns/tests/smoke.rs rename to protocols/mdns/tests/use-async-std.rs index d123e5abce7..683aed338ce 100644 --- a/protocols/mdns/tests/smoke.rs +++ b/protocols/mdns/tests/use-async-std.rs @@ -28,6 +28,35 @@ use libp2p::{ use std::error::Error; use std::time::Duration; +#[async_std::test] +async fn test_discovery_async_std_ipv4() -> Result<(), Box> { + run_discovery_test(MdnsConfig::default()).await +} + +#[async_std::test] +async fn test_discovery_async_std_ipv6() -> Result<(), Box> { + let config = MdnsConfig { + enable_ipv6: true, + ..Default::default() + }; + run_discovery_test(config).await +} + +#[async_std::test] +async fn test_expired_async_std() -> Result<(), Box> { + env_logger::try_init().ok(); + let config = MdnsConfig { + ttl: Duration::from_secs(1), + query_interval: Duration::from_secs(10), + ..Default::default() + }; + + async_std::future::timeout(Duration::from_secs(6), run_peer_expiration_test(config)) + .await + .map(|_| ()) + .map_err(|e| Box::new(e) as Box) +} + async fn create_swarm(config: MdnsConfig) -> Result, Box> { let id_keys = identity::Keypair::generate_ed25519(); let peer_id = PeerId::from(id_keys.public()); @@ -78,34 +107,6 @@ async fn run_discovery_test(config: MdnsConfig) -> Result<(), Box> { } } -#[async_std::test] -async fn test_discovery_async_std_ipv4() -> Result<(), Box> { - run_discovery_test(MdnsConfig::default()).await -} - -#[tokio::test] -async fn test_discovery_tokio_ipv4() -> Result<(), Box> { - run_discovery_test(MdnsConfig::default()).await -} - -#[async_std::test] -async fn test_discovery_async_std_ipv6() -> Result<(), Box> { - let config = MdnsConfig { - enable_ipv6: true, - ..Default::default() - }; - run_discovery_test(config).await -} - -#[tokio::test] -async fn test_discovery_tokio_ipv6() -> Result<(), Box> { - let config = MdnsConfig { - enable_ipv6: true, - ..Default::default() - }; - run_discovery_test(config).await -} - async fn run_peer_expiration_test(config: MdnsConfig) -> Result<(), Box> { let mut a = create_swarm(config.clone()).await?; let mut b = create_swarm(config).await?; @@ -136,32 +137,3 @@ async fn run_peer_expiration_test(config: MdnsConfig) -> Result<(), Box Result<(), Box> { - env_logger::try_init().ok(); - let config = MdnsConfig { - ttl: Duration::from_secs(1), - query_interval: Duration::from_secs(10), - ..Default::default() - }; - - async_std::future::timeout(Duration::from_secs(6), run_peer_expiration_test(config)) - .await - .map(|_| ()) - .map_err(|e| Box::new(e) as Box) -} - -#[tokio::test] -async fn test_expired_tokio() -> Result<(), Box> { - env_logger::try_init().ok(); - let config = MdnsConfig { - ttl: Duration::from_secs(1), - query_interval: Duration::from_secs(10), - ..Default::default() - }; - - tokio::time::timeout(Duration::from_secs(6), run_peer_expiration_test(config)) - .await - .unwrap() -} diff --git a/protocols/mdns/tests/use-tokio.rs b/protocols/mdns/tests/use-tokio.rs new file mode 100644 index 00000000000..9d6cacd76cb --- /dev/null +++ b/protocols/mdns/tests/use-tokio.rs @@ -0,0 +1,153 @@ +// Copyright 2018 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE.use futures::StreamExt; +use futures::StreamExt; +use libp2p::{ + identity, + mdns::{MdnsConfig, MdnsEvent, TokioMdns}, + swarm::{Swarm, SwarmEvent}, + PeerId, +}; +use std::error::Error; +use std::time::Duration; + +#[tokio::test] +async fn test_discovery_tokio_ipv4() -> Result<(), Box> { + run_discovery_test(MdnsConfig::default()).await +} + +#[tokio::test] +async fn test_discovery_tokio_ipv6() -> Result<(), Box> { + let config = MdnsConfig { + enable_ipv6: true, + ..Default::default() + }; + run_discovery_test(config).await +} + +#[tokio::test] +async fn test_expired_tokio() -> Result<(), Box> { + env_logger::try_init().ok(); + let config = MdnsConfig { + ttl: Duration::from_secs(1), + query_interval: Duration::from_secs(10), + ..Default::default() + }; + + run_peer_expiration_test(config).await +} + +async fn create_swarm(config: MdnsConfig) -> Result, Box> { + let id_keys = identity::Keypair::generate_ed25519(); + let peer_id = PeerId::from(id_keys.public()); + let transport = libp2p::tokio_development_transport(id_keys)?; + let behaviour = TokioMdns::new(config).await?; + let mut swarm = Swarm::new(transport, behaviour, peer_id); + swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?; + Ok(swarm) +} + +async fn run_discovery_test(config: MdnsConfig) -> Result<(), Box> { + env_logger::try_init().ok(); + let mut a = create_swarm(config.clone()).await?; + let mut b = create_swarm(config).await?; + let mut discovered_a = false; + let mut discovered_b = false; + loop { + futures::select! { + ev = a.select_next_some() => match ev { + SwarmEvent::Behaviour(MdnsEvent::Discovered(peers)) => { + for (peer, _addr) in peers { + if peer == *b.local_peer_id() { + if discovered_a { + return Ok(()); + } else { + discovered_b = true; + } + } + } + } + _ => {} + }, + ev = b.select_next_some() => match ev { + SwarmEvent::Behaviour(MdnsEvent::Discovered(peers)) => { + for (peer, _addr) in peers { + if peer == *a.local_peer_id() { + if discovered_b { + return Ok(()); + } else { + discovered_a = true; + } + } + } + } + _ => {} + } + } + } +} + +async fn run_peer_expiration_test(config: MdnsConfig) -> Result<(), Box> { + let mut a = create_swarm(config.clone()).await?; + let mut b = create_swarm(config).await?; + let expired_at = tokio::time::sleep(Duration::from_secs(15)); + tokio::pin!(expired_at); + + loop { + tokio::select! { + _ev = &mut expired_at => { + panic!(); + }, + ev = a.select_next_some() => match ev { + SwarmEvent::Behaviour(MdnsEvent::Expired(peers)) => { + for (peer, _addr) in peers { + if peer == *b.local_peer_id() { + return Ok(()); + } + } + } + SwarmEvent::Behaviour(MdnsEvent::Discovered(peers)) => { + for (peer, _addr) in peers { + if peer == *b.local_peer_id() { + expired_at.as_mut().reset(tokio::time::Instant::now() + tokio::time::Duration::from_secs(2)); + } + } + } + _ => {} + }, + ev = b.select_next_some() => match ev { + SwarmEvent::Behaviour(MdnsEvent::Expired(peers)) => { + for (peer, _addr) in peers { + if peer == *a.local_peer_id() { + return Ok(()); + } + } + } + SwarmEvent::Behaviour(MdnsEvent::Discovered(peers)) => { + for (peer, _addr) in peers { + if peer == *a.local_peer_id() { + expired_at.as_mut().reset(tokio::time::Instant::now() + tokio::time::Duration::from_secs(2)); + } + } + } + _ => {} + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6bb577b1f52..3ed00408cb5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,8 +79,11 @@ pub use libp2p_identify as identify; #[cfg_attr(docsrs, doc(cfg(feature = "kad")))] #[doc(inline)] pub use libp2p_kad as kad; -#[cfg(feature = "mdns")] -#[cfg_attr(docsrs, doc(cfg(feature = "mdns")))] +#[cfg(any(feature = "mdns-async-io", feature = "mdns-tokio"))] +#[cfg_attr( + docsrs, + doc(cfg(any(feature = "mdns-tokio", feature = "mdns-async-io"))) +)] #[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] #[doc(inline)] pub use libp2p_mdns as mdns;