diff --git a/protocols/kad/CHANGELOG.md b/protocols/kad/CHANGELOG.md index b2a06dc691c..a217fa78b4f 100644 --- a/protocols/kad/CHANGELOG.md +++ b/protocols/kad/CHANGELOG.md @@ -15,6 +15,8 @@ See [PR 5122](https://github.com/libp2p/rust-libp2p/pull/5122). - Compute `jobs_query_capacity` accurately. See [PR 5148](https://github.com/libp2p/rust-libp2p/pull/5148). +- Introduce `AsyncBehaviour`, a wrapper of `Behaviour` allowing to easily track Kademlia queries. + See [PR 5294](https://github.com/libp2p/rust-libp2p/pull/5294). ## 0.45.3 diff --git a/protocols/kad/src/async_behaviour.rs b/protocols/kad/src/async_behaviour.rs new file mode 100644 index 00000000000..d6b02013c1e --- /dev/null +++ b/protocols/kad/src/async_behaviour.rs @@ -0,0 +1,356 @@ +use std::{collections::HashMap, task::Poll}; + +use futures::{channel::mpsc, StreamExt}; +use libp2p_core::{Endpoint, Multiaddr}; +use libp2p_identity::PeerId; +use libp2p_swarm::{ + ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, + THandlerOutEvent, ToSwarm, +}; + +use crate::{ + kbucket, store, AddProviderResult, Behaviour, BootstrapResult, Event, GetClosestPeersResult, + GetProvidersResult, GetRecordResult, NoKnownPeers, PutRecordResult, QueryId, QueryResult, + QueryStats, Quorum, Record, RecordKey, +}; + +/// The results of Kademlia queries (strongly typed) and only +/// those initiated by the user. +pub struct AsyncQueryResult { + pub id: QueryId, + pub result: T, + pub stats: QueryStats, +} +impl AsyncQueryResult { + fn map(self, f: impl Fn(T) -> Out) -> AsyncQueryResult { + AsyncQueryResult { + id: self.id, + stats: self.stats, + result: f(self.result), + } + } +} + +type UnboundedQueryResultSender = mpsc::UnboundedSender>; + +enum QueryResultSender { + Bootstrap(UnboundedQueryResultSender), + GetClosestPeers(UnboundedQueryResultSender), + GetProviders(UnboundedQueryResultSender), + StartProviding(UnboundedQueryResultSender), + GetRecord(UnboundedQueryResultSender), + PutRecord(UnboundedQueryResultSender), +} + +/// A handle to receive [`AsyncQueryResult`]. +#[must_use = "Streams do nothing unless polled."] +pub struct AsyncQueryResultStream(mpsc::UnboundedReceiver>); +impl futures::Stream for AsyncQueryResultStream { + type Item = AsyncQueryResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_next_unpin(cx) + } +} + +/// A wrapper of [`Behaviour`] allowing to easily track +/// [`QueryResult`] of user initiated Kademlia queries. +/// +/// For each queries like [`Behaviour::bootstrap`], [`Behaviour::get_closest_peers`], etc +/// a corresponding method ([`AsyncBehaviour::bootstrap_async`], [`AsyncBehaviour::get_closest_peers_async`]) +/// is available, allowing the developer to be notified from a typed [`AsyncQueryResultStream`] +/// instead from the normal [`Event::OutboundQueryProgressed`] event. +/// +/// If a [`QueryResult`] has no corresponding [`AsyncQueryResultStream`] or +/// if the corresponding one has been dropped, it will simply be forwarded to the `Swarm` +/// with an [`Event::OutboundQueryProgressed`] like nothing happen. +/// +/// For more information, see [`Behaviour`]. +pub struct AsyncBehaviour { + inner: Behaviour, + query_result_senders: HashMap, +} + +impl std::ops::Deref for AsyncBehaviour { + type Target = Behaviour; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl std::ops::DerefMut for AsyncBehaviour { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl AsyncBehaviour +where + TStore: store::RecordStore + Send + 'static, +{ + pub fn new(inner: Behaviour) -> Self { + Self { + inner, + query_result_senders: Default::default(), + } + } + + fn handle_inner_event(&mut self, event: Event) -> Option<::ToSwarm> { + match event { + Event::OutboundQueryProgressed { + id, + result, + stats, + step, + } => { + fn do_send( + sender: &UnboundedQueryResultSender, + id: QueryId, + result: T, + stats: QueryStats, + ) -> Option> { + match sender.unbounded_send(AsyncQueryResult { id, result, stats }) { + Ok(_) => { + // The event has been successfully sent into the channel so there is no + // need to forward it backup to the swarm. + None + } + Err(err) => { + // The receiver is closed. This is probably normal (the user got what he needed and dropped the receiver). + // So we don't log anything but we still forward this event back up to the swarm. + Some(err.into_inner()) + } + } + } + + let Some(sender) = self.query_result_senders.get(&id) else { + // This query was either not triggered by the user or the receiver has been dropped and removed + // so we simply forward it back up to the swarm like nothing happened. + return Some(Event::OutboundQueryProgressed { + id, + result, + stats, + step, + }); + }; + let event = match (result, sender) { + (QueryResult::Bootstrap(result), QueryResultSender::Bootstrap(sender)) => { + do_send(sender, id, result, stats).map(|qr| qr.map(QueryResult::Bootstrap)) + } + ( + QueryResult::GetClosestPeers(result), + QueryResultSender::GetClosestPeers(sender), + ) => do_send(sender, id, result, stats) + .map(|qr| qr.map(QueryResult::GetClosestPeers)), + ( + QueryResult::GetProviders(result), + QueryResultSender::GetProviders(sender), + ) => do_send(sender, id, result, stats) + .map(|qr| qr.map(QueryResult::GetProviders)), + ( + QueryResult::StartProviding(result), + QueryResultSender::StartProviding(sender), + ) => do_send(sender, id, result, stats) + .map(|qr| qr.map(QueryResult::StartProviding)), + (QueryResult::GetRecord(result), QueryResultSender::GetRecord(sender)) => { + do_send(sender, id, result, stats).map(|qr| qr.map(QueryResult::GetRecord)) + } + (QueryResult::PutRecord(result), QueryResultSender::PutRecord(sender)) => { + do_send(sender, id, result, stats).map(|qr| qr.map(QueryResult::PutRecord)) + } + (result, _) => { + unreachable!("Wrong sender type for query {id} : unable to send {result:?}") + } + }; + + if let Some(AsyncQueryResult { id, result, stats }) = event { + // The receiver was closed so we were unable to send the result. + // We remove the sender and forward the event back up to the swarm + self.query_result_senders.remove(&id); + return Some(Event::OutboundQueryProgressed { + id, + result, + stats, + step, + }); + } + + if step.last { + // This was the last query_result and we just send it successfully. + // We remove the sender. Dropping it will close the channel and + // the receiver will be notified. + self.query_result_senders.remove(&id); + } + + None + } + event => Some(event), + } + } + + fn add_query( + &mut self, + query_id: QueryId, + f: impl Fn(UnboundedQueryResultSender) -> QueryResultSender, + ) -> AsyncQueryResultStream { + let (tx, rx) = mpsc::unbounded(); + self.query_result_senders.insert(query_id, f(tx)); + AsyncQueryResultStream(rx) + } + + /// Start a Bootstrap query and capture its results in a typed [`AsyncQueryResultStream`]. + /// + /// For more information, see [`Behaviour::bootstrap`]. + pub fn bootstrap_async( + &mut self, + ) -> Result, NoKnownPeers> { + let query_id = self.inner.bootstrap()?; + Ok(self.add_query(query_id, QueryResultSender::Bootstrap)) + } + + /// Start a GetClosestPeers query and capture its results in a typed [`AsyncQueryResultStream`]. + /// + /// For more information, see [`Behaviour::get_closest_peers`]. + pub fn get_closest_peers_async( + &mut self, + key: K, + ) -> AsyncQueryResultStream + where + K: Into> + Into> + Clone, + { + let query_id = self.inner.get_closest_peers(key); + self.add_query(query_id, QueryResultSender::GetClosestPeers) + } + + /// Start a GetProviders query and capture its results in a typed [`AsyncQueryResultStream`]. + /// + /// For more information, see [`Behaviour::get_providers`]. + pub fn get_providers_async( + &mut self, + key: RecordKey, + ) -> AsyncQueryResultStream { + let query_id = self.inner.get_providers(key); + self.add_query(query_id, QueryResultSender::GetProviders) + } + + /// Start a StartProviding query and capture its results in a typed [`AsyncQueryResultStream`]. + /// + /// For more information, see [`Behaviour::start_providing`]. + pub fn start_providing_async( + &mut self, + key: RecordKey, + ) -> Result, store::Error> { + let query_id = self.inner.start_providing(key)?; + Ok(self.add_query(query_id, QueryResultSender::StartProviding)) + } + + /// Start a GetRecord query and capture its results in a typed [`AsyncQueryResultStream`]. + /// + /// For more information, see [`Behaviour::get_record`]. + pub fn get_record_async(&mut self, key: RecordKey) -> AsyncQueryResultStream { + let query_id = self.inner.get_record(key); + self.add_query(query_id, QueryResultSender::GetRecord) + } + + /// Start a PutRecord query and capture its results in a typed [`AsyncQueryResultStream`]. + /// + /// For more information, see [`Behaviour::put_record`]. + pub fn put_record_async( + &mut self, + record: Record, + quorum: Quorum, + ) -> Result, store::Error> { + let query_id = self.inner.put_record(record, quorum)?; + Ok(self.add_query(query_id, QueryResultSender::PutRecord)) + } +} + +impl NetworkBehaviour for AsyncBehaviour +where + TStore: store::RecordStore + Send + 'static, +{ + type ConnectionHandler = as NetworkBehaviour>::ConnectionHandler; + type ToSwarm = as NetworkBehaviour>::ToSwarm; + + fn handle_pending_inbound_connection( + &mut self, + connection_id: ConnectionId, + local_addr: &Multiaddr, + remote_addr: &Multiaddr, + ) -> Result<(), ConnectionDenied> { + self.inner + .handle_pending_inbound_connection(connection_id, local_addr, remote_addr) + } + + fn handle_established_inbound_connection( + &mut self, + connection_id: ConnectionId, + peer: PeerId, + local_addr: &Multiaddr, + remote_addr: &Multiaddr, + ) -> Result, ConnectionDenied> { + self.inner.handle_established_inbound_connection( + connection_id, + peer, + local_addr, + remote_addr, + ) + } + + fn handle_pending_outbound_connection( + &mut self, + connection_id: ConnectionId, + maybe_peer: Option, + addresses: &[Multiaddr], + effective_role: Endpoint, + ) -> Result, ConnectionDenied> { + self.inner.handle_pending_outbound_connection( + connection_id, + maybe_peer, + addresses, + effective_role, + ) + } + + fn handle_established_outbound_connection( + &mut self, + connection_id: ConnectionId, + peer: PeerId, + addr: &Multiaddr, + role_override: Endpoint, + ) -> Result, ConnectionDenied> { + self.inner + .handle_established_outbound_connection(connection_id, peer, addr, role_override) + } + + fn on_swarm_event(&mut self, event: FromSwarm<'_>) { + self.inner.on_swarm_event(event); + } + + fn on_connection_handler_event( + &mut self, + peer_id: PeerId, + connection_id: ConnectionId, + event: THandlerOutEvent, + ) { + self.inner + .on_connection_handler_event(peer_id, connection_id, event); + } + + fn poll( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + while let Poll::Ready(event) = self.inner.poll(cx) { + if let Some(event) = event.map_out_opt(|e| self.handle_inner_event(e)) { + return Poll::Ready(event); + } + } + + Poll::Pending + } +} diff --git a/protocols/kad/src/lib.rs b/protocols/kad/src/lib.rs index bc01b9fd3ce..54dc0185b9e 100644 --- a/protocols/kad/src/lib.rs +++ b/protocols/kad/src/lib.rs @@ -36,6 +36,7 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] mod addresses; +mod async_behaviour; mod behaviour; mod bootstrap; mod handler; @@ -55,6 +56,7 @@ mod proto { } pub use addresses::Addresses; +pub use async_behaviour::{AsyncBehaviour, AsyncQueryResult, AsyncQueryResultStream}; pub use behaviour::{ AddProviderContext, AddProviderError, AddProviderOk, AddProviderPhase, AddProviderResult, BootstrapError, BootstrapOk, BootstrapResult, GetClosestPeersError, GetClosestPeersOk, diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md index 75e18a6a5af..fb8c3028208 100644 --- a/swarm/CHANGELOG.md +++ b/swarm/CHANGELOG.md @@ -5,6 +5,8 @@ The address is broadcast to all behaviours via `FromSwarm::NewExternalAddrOfPeer`. Protocols that want to collect these addresses can use the new `PeerAddresses` utility. See [PR 4371](https://github.com/libp2p/rust-libp2p/pull/4371). +- Add utility function `map_out_opt` on `ToSwarm`. + See [PR 5294](https://github.com/libp2p/rust-libp2p/pull/5294). ## 0.44.1 diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index 5070871a4c1..20f725e33b4 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -387,6 +387,43 @@ impl ToSwarm { }, } } + + /// Map the event the swarm will return to an optional one. + pub fn map_out_opt( + self, + f: impl FnOnce(TOutEvent) -> Option, + ) -> Option> { + match self { + ToSwarm::GenerateEvent(e) => f(e).map(ToSwarm::GenerateEvent), + ToSwarm::Dial { opts } => Some(ToSwarm::Dial { opts }), + ToSwarm::ListenOn { opts } => Some(ToSwarm::ListenOn { opts }), + ToSwarm::RemoveListener { id } => Some(ToSwarm::RemoveListener { id }), + ToSwarm::NotifyHandler { + peer_id, + handler, + event, + } => Some(ToSwarm::NotifyHandler { + peer_id, + handler, + event, + }), + ToSwarm::NewExternalAddrCandidate(addr) => { + Some(ToSwarm::NewExternalAddrCandidate(addr)) + } + ToSwarm::ExternalAddrConfirmed(addr) => Some(ToSwarm::ExternalAddrConfirmed(addr)), + ToSwarm::ExternalAddrExpired(addr) => Some(ToSwarm::ExternalAddrExpired(addr)), + ToSwarm::CloseConnection { + peer_id, + connection, + } => Some(ToSwarm::CloseConnection { + peer_id, + connection, + }), + ToSwarm::NewExternalAddrOfPeer { peer_id, address } => { + Some(ToSwarm::NewExternalAddrOfPeer { peer_id, address }) + } + } + } } /// The options w.r.t. which connection handler to notify of an event.