From add1ff6e50b335f2c8cacdb9e0f4dba90ad22f3b Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 23 Nov 2023 10:36:09 +1100 Subject: [PATCH] refactor(kad): use `oneshot`s and `async-await` for outbound streams This refactoring addresses several aspects of the current handler implementation: - Remove the manual state machine for outbound streams in favor of using `async-await`. - Use `oneshot`s to track the number of requested outbound streams - Use `futures_bounded::FuturesMap` to track the execution of a stream, thus applying a timeout to the entire request. Resolves: #3130. Related: #3268. Related: #4510. Pull-Request: #4901. --- Cargo.lock | 1 + protocols/kad/Cargo.toml | 1 + protocols/kad/src/behaviour.rs | 1 + protocols/kad/src/handler.rs | 374 ++++++++++++--------------------- 4 files changed, 135 insertions(+), 242 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4a7655a3004..e037dcac20b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2720,6 +2720,7 @@ dependencies = [ "either", "fnv", "futures", + "futures-bounded", "futures-timer", "instant", "libp2p-core", diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index 04101d51026..1e4c788cf00 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -19,6 +19,7 @@ asynchronous-codec = { workspace = true } futures = "0.3.29" libp2p-core = { workspace = true } libp2p-swarm = { workspace = true } +futures-bounded = { workspace = true } quick-protobuf = "0.8" quick-protobuf-codec = { workspace = true } libp2p-identity = { workspace = true, features = ["rand"] } diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index 5a4b737c998..cde4fbb8536 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -3173,6 +3173,7 @@ impl QueryInfo { multiaddrs: external_addresses.clone(), connection_ty: crate::protocol::ConnectionType::Connected, }, + query_id, }, }, QueryInfo::GetRecord { key, .. } => HandlerIn::GetRecord { diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index adfb076541c..318261d8d21 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -25,22 +25,22 @@ use crate::protocol::{ use crate::record::{self, Record}; use crate::QueryId; use either::Either; +use futures::channel::oneshot; use futures::prelude::*; use futures::stream::SelectAll; use libp2p_core::{upgrade, ConnectedPoint}; use libp2p_identity::PeerId; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, -}; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound}; use libp2p_swarm::{ ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol, SupportedProtocols, }; use std::collections::VecDeque; use std::task::Waker; +use std::time::Duration; use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll}; -const MAX_NUM_SUBSTREAMS: usize = 32; +const MAX_NUM_STREAMS: usize = 32; /// Protocol handler that manages substreams for the Kademlia protocol /// on a single connection with a peer. @@ -59,15 +59,16 @@ pub struct Handler { /// Next unique ID of a connection. next_connec_unique_id: UniqueConnecId, - /// List of active outbound substreams with the state they are in. - outbound_substreams: SelectAll, + /// List of active outbound streams. + outbound_substreams: futures_bounded::FuturesMap>>, - /// Number of outbound streams being upgraded right now. - num_requested_outbound_streams: usize, + /// Contains one [`oneshot::Sender`] per outbound stream that we have requested. + pending_streams: + VecDeque, StreamUpgradeError>>>, /// List of outbound substreams that are waiting to become active next. /// Contains the request we want to send, and the user data if we expect an answer. - pending_messages: VecDeque<(KadRequestMsg, Option)>, + pending_messages: VecDeque<(KadRequestMsg, QueryId)>, /// List of active inbound substreams with the state they are in. inbound_substreams: SelectAll, @@ -95,24 +96,6 @@ struct ProtocolStatus { reported: bool, } -/// State of an active outbound substream. -enum OutboundSubstreamState { - /// Waiting to send a message to the remote. - PendingSend(KadOutStreamSink, KadRequestMsg, Option), - /// Waiting to flush the substream so that the data arrives to the remote. - PendingFlush(KadOutStreamSink, Option), - /// Waiting for an answer back from the remote. - // TODO: add timeout - WaitingAnswer(KadOutStreamSink, QueryId), - /// An error happened on the substream and we should report the error to the user. - ReportError(HandlerQueryErr, QueryId), - /// The substream is being closed. - Closing(KadOutStreamSink), - /// The substream is complete and will not perform any more work. - Done, - Poisoned, -} - /// State of an active inbound substream. enum InboundSubstreamState { /// Waiting for a request from the remote. @@ -292,8 +275,6 @@ pub enum HandlerEvent { /// Error that can happen when requesting an RPC query. #[derive(Debug)] pub enum HandlerQueryErr { - /// Error while trying to perform the query. - Upgrade(StreamUpgradeError), /// Received an answer that doesn't correspond to the request. UnexpectedMessage, /// I/O error in the substream. @@ -303,9 +284,6 @@ pub enum HandlerQueryErr { impl fmt::Display for HandlerQueryErr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - HandlerQueryErr::Upgrade(err) => { - write!(f, "Error while performing Kademlia query: {err}") - } HandlerQueryErr::UnexpectedMessage => { write!( f, @@ -322,19 +300,12 @@ impl fmt::Display for HandlerQueryErr { impl error::Error for HandlerQueryErr { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { - HandlerQueryErr::Upgrade(err) => Some(err), HandlerQueryErr::UnexpectedMessage => None, HandlerQueryErr::Io(err) => Some(err), } } } -impl From> for HandlerQueryErr { - fn from(err: StreamUpgradeError) -> Self { - HandlerQueryErr::Upgrade(err) - } -} - /// Event to send to the handler. #[derive(Debug)] pub enum HandlerIn { @@ -355,7 +326,7 @@ pub enum HandlerIn { FindNodeReq { /// Identifier of the node. key: Vec, - /// Custom user data. Passed back in the out event when the results arrive. + /// ID of the query that generated this request. query_id: QueryId, }, @@ -374,7 +345,7 @@ pub enum HandlerIn { GetProvidersReq { /// Identifier being searched. key: record::Key, - /// Custom user data. Passed back in the out event when the results arrive. + /// ID of the query that generated this request. query_id: QueryId, }, @@ -399,13 +370,15 @@ pub enum HandlerIn { key: record::Key, /// Known provider for this key. provider: KadPeer, + /// ID of the query that generated this request. + query_id: QueryId, }, /// Request to retrieve a record from the DHT. GetRecord { /// The key of the record. key: record::Key, - /// Custom data. Passed back in the out event when the results arrive. + /// ID of the query that generated this request. query_id: QueryId, }, @@ -422,7 +395,7 @@ pub enum HandlerIn { /// Put a value into the dht records. PutRecord { record: Record, - /// Custom data. Passed back in the out event when the results arrive. + /// ID of the query that generated this request. query_id: QueryId, }, @@ -480,8 +453,11 @@ impl Handler { remote_peer_id, next_connec_unique_id: UniqueConnecId(0), inbound_substreams: Default::default(), - outbound_substreams: Default::default(), - num_requested_outbound_streams: 0, + outbound_substreams: futures_bounded::FuturesMap::new( + Duration::from_secs(10), + MAX_NUM_STREAMS, + ), + pending_streams: Default::default(), pending_messages: Default::default(), protocol_status: None, remote_supported_protocols: Default::default(), @@ -490,20 +466,18 @@ impl Handler { fn on_fully_negotiated_outbound( &mut self, - FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound< + FullyNegotiatedOutbound { + protocol: stream, + info: (), + }: FullyNegotiatedOutbound< ::OutboundProtocol, ::OutboundOpenInfo, >, ) { - if let Some((msg, query_id)) = self.pending_messages.pop_front() { - self.outbound_substreams - .push(OutboundSubstreamState::PendingSend(protocol, msg, query_id)); - } else { - debug_assert!(false, "Requested outbound stream without message") + if let Some(sender) = self.pending_streams.pop_front() { + let _ = sender.send(Ok(stream)); } - self.num_requested_outbound_streams -= 1; - if self.protocol_status.is_none() { // Upon the first successfully negotiated substream, we know that the // remote is configured with the same protocol name and we want @@ -539,7 +513,7 @@ impl Handler { }); } - if self.inbound_substreams.len() == MAX_NUM_SUBSTREAMS { + if self.inbound_substreams.len() == MAX_NUM_STREAMS { if let Some(s) = self.inbound_substreams.iter_mut().find(|s| { matches!( s, @@ -573,24 +547,42 @@ impl Handler { }); } - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { - info: (), error, .. - }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - // TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't - // continue trying + /// Takes the given [`KadRequestMsg`] and composes it into an outbound request-response protocol handshake using a [`oneshot::channel`]. + fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) { + let (sender, receiver) = oneshot::channel(); + + self.pending_streams.push_back(sender); + let result = self.outbound_substreams.try_push(id, async move { + let mut stream = receiver + .await + .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))? + .map_err(|e| match e { + StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(), + StreamUpgradeError::Apply(e) => e, + StreamUpgradeError::NegotiationFailed => { + io::Error::new(io::ErrorKind::ConnectionRefused, "protocol not supported") + } + StreamUpgradeError::Io(e) => e, + })?; - if let Some((_, Some(query_id))) = self.pending_messages.pop_front() { - self.outbound_substreams - .push(OutboundSubstreamState::ReportError(error.into(), query_id)); - } + let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. }); + + stream.send(msg).await?; + stream.close().await?; + + if !has_answer { + return Ok(None); + } - self.num_requested_outbound_streams -= 1; + let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??; + + Ok(Some(msg)) + }); + + debug_assert!( + result.is_ok(), + "Expected to not create more streams than allowed" + ); } } @@ -627,7 +619,7 @@ impl ConnectionHandler for Handler { } HandlerIn::FindNodeReq { key, query_id } => { let msg = KadRequestMsg::FindNode { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_messages.push_back((msg, query_id)); } HandlerIn::FindNodeRes { closer_peers, @@ -635,7 +627,7 @@ impl ConnectionHandler for Handler { } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }), HandlerIn::GetProvidersReq { key, query_id } => { let msg = KadRequestMsg::GetProviders { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_messages.push_back((msg, query_id)); } HandlerIn::GetProvidersRes { closer_peers, @@ -648,17 +640,21 @@ impl ConnectionHandler for Handler { provider_peers, }, ), - HandlerIn::AddProvider { key, provider } => { + HandlerIn::AddProvider { + key, + provider, + query_id, + } => { let msg = KadRequestMsg::AddProvider { key, provider }; - self.pending_messages.push_back((msg, None)); + self.pending_messages.push_back((msg, query_id)); } HandlerIn::GetRecord { key, query_id } => { let msg = KadRequestMsg::GetValue { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_messages.push_back((msg, query_id)); } HandlerIn::PutRecord { record, query_id } => { let msg = KadRequestMsg::PutValue { record }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_messages.push_back((msg, query_id)); } HandlerIn::GetRecordRes { record, @@ -712,44 +708,68 @@ impl ConnectionHandler for Handler { ) -> Poll< ConnectionHandlerEvent, > { - match &mut self.protocol_status { - Some(status) if !status.reported => { - status.reported = true; - let event = if status.supported { - HandlerEvent::ProtocolConfirmed { - endpoint: self.endpoint.clone(), - } - } else { - HandlerEvent::ProtocolNotSupported { - endpoint: self.endpoint.clone(), - } - }; + loop { + match &mut self.protocol_status { + Some(status) if !status.reported => { + status.reported = true; + let event = if status.supported { + HandlerEvent::ProtocolConfirmed { + endpoint: self.endpoint.clone(), + } + } else { + HandlerEvent::ProtocolNotSupported { + endpoint: self.endpoint.clone(), + } + }; - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + _ => {} } - _ => {} - } - if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); - } + match self.outbound_substreams.poll_unpin(cx) { + Poll::Ready((query, Ok(Ok(Some(response))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + process_kad_response(response, query), + )) + } + Poll::Ready((_, Ok(Ok(None)))) => { + continue; + } + Poll::Ready((query_id, Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + HandlerEvent::QueryError { + error: HandlerQueryErr::Io(e), + query_id, + }, + )) + } + Poll::Ready((query_id, Err(_timeout))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + HandlerEvent::QueryError { + error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()), + query_id, + }, + )) + } + Poll::Pending => {} + } - if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); - } + if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); + } - let num_in_progress_outbound_substreams = - self.outbound_substreams.len() + self.num_requested_outbound_streams; - if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS - && self.num_requested_outbound_streams < self.pending_messages.len() - { - self.num_requested_outbound_streams += 1; - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()), - }); - } + if self.outbound_substreams.len() < MAX_NUM_STREAMS { + if let Some((msg, id)) = self.pending_messages.pop_front() { + self.queue_new_stream(id, msg); + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()), + }); + } + } - Poll::Pending + return Poll::Pending; + } } fn on_connection_event( @@ -768,8 +788,10 @@ impl ConnectionHandler for Handler { ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { self.on_fully_negotiated_inbound(fully_negotiated_inbound) } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::DialUpgradeError(ev) => { + if let Some(sender) = self.pending_streams.pop_front() { + let _ = sender.send(Err(ev.error)); + } } ConnectionEvent::RemoteProtocolsChange(change) => { let dirty = self.remote_supported_protocols.on_protocols_change(change); @@ -839,138 +861,6 @@ impl Handler { } } -impl futures::Stream for OutboundSubstreamState { - type Item = ConnectionHandlerEvent; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - loop { - match std::mem::replace(this, OutboundSubstreamState::Poisoned) { - OutboundSubstreamState::PendingSend(mut substream, msg, query_id) => { - match substream.poll_ready_unpin(cx) { - Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) { - Ok(()) => { - *this = OutboundSubstreamState::PendingFlush(substream, query_id); - } - Err(error) => { - *this = OutboundSubstreamState::Done; - let event = query_id.map(|query_id| { - ConnectionHandlerEvent::NotifyBehaviour( - HandlerEvent::QueryError { - error: HandlerQueryErr::Io(error), - query_id, - }, - ) - }); - - return Poll::Ready(event); - } - }, - Poll::Pending => { - *this = OutboundSubstreamState::PendingSend(substream, msg, query_id); - return Poll::Pending; - } - Poll::Ready(Err(error)) => { - *this = OutboundSubstreamState::Done; - let event = query_id.map(|query_id| { - ConnectionHandlerEvent::NotifyBehaviour(HandlerEvent::QueryError { - error: HandlerQueryErr::Io(error), - query_id, - }) - }); - - return Poll::Ready(event); - } - } - } - OutboundSubstreamState::PendingFlush(mut substream, query_id) => { - match substream.poll_flush_unpin(cx) { - Poll::Ready(Ok(())) => { - if let Some(query_id) = query_id { - *this = OutboundSubstreamState::WaitingAnswer(substream, query_id); - } else { - *this = OutboundSubstreamState::Closing(substream); - } - } - Poll::Pending => { - *this = OutboundSubstreamState::PendingFlush(substream, query_id); - return Poll::Pending; - } - Poll::Ready(Err(error)) => { - *this = OutboundSubstreamState::Done; - let event = query_id.map(|query_id| { - ConnectionHandlerEvent::NotifyBehaviour(HandlerEvent::QueryError { - error: HandlerQueryErr::Io(error), - query_id, - }) - }); - - return Poll::Ready(event); - } - } - } - OutboundSubstreamState::WaitingAnswer(mut substream, query_id) => { - match substream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(msg))) => { - *this = OutboundSubstreamState::Closing(substream); - let event = process_kad_response(msg, query_id); - - return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour( - event, - ))); - } - Poll::Pending => { - *this = OutboundSubstreamState::WaitingAnswer(substream, query_id); - return Poll::Pending; - } - Poll::Ready(Some(Err(error))) => { - *this = OutboundSubstreamState::Done; - let event = HandlerEvent::QueryError { - error: HandlerQueryErr::Io(error), - query_id, - }; - - return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour( - event, - ))); - } - Poll::Ready(None) => { - *this = OutboundSubstreamState::Done; - let event = HandlerEvent::QueryError { - error: HandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()), - query_id, - }; - - return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour( - event, - ))); - } - } - } - OutboundSubstreamState::ReportError(error, query_id) => { - *this = OutboundSubstreamState::Done; - let event = HandlerEvent::QueryError { error, query_id }; - - return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(event))); - } - OutboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) { - Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None), - Poll::Pending => { - *this = OutboundSubstreamState::Closing(stream); - return Poll::Pending; - } - }, - OutboundSubstreamState::Done => { - *this = OutboundSubstreamState::Done; - return Poll::Ready(None); - } - OutboundSubstreamState::Poisoned => unreachable!(), - } - } - } -} - impl futures::Stream for InboundSubstreamState { type Item = ConnectionHandlerEvent;