Skip to content

Commit

Permalink
Use InflightProtocolDataQueue in libp2p-request-response
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaseizinger committed Nov 14, 2023
1 parent 76ad5b1 commit f600aa1
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 164 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions protocols/request-response/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ instant = "0.1.12"
libp2p-core = { workspace = true }
libp2p-swarm = { workspace = true }
libp2p-identity = { workspace = true }
libp2p-protocol-utils = { workspace = true }
rand = "0.8"
serde = { version = "1.0", optional = true}
serde_json = { version = "1.0.108", optional = true }
Expand Down
288 changes: 124 additions & 164 deletions protocols/request-response/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD};

use futures::channel::mpsc;
use futures::{channel::oneshot, prelude::*};
use libp2p_swarm::handler::{
ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
ListenUpgradeError,
};
use libp2p_protocol_utils::InflightProtocolDataQueue;
use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound};
use libp2p_swarm::{
handler::{ConnectionHandler, ConnectionHandlerEvent, StreamUpgradeError},
SubstreamProtocol,
Expand All @@ -47,6 +45,7 @@ use std::{
task::{Context, Poll},
time::Duration,
};
use void::Void;

/// A connection handler for a request response [`Behaviour`](super::Behaviour) protocol.
pub struct Handler<TCodec>
Expand All @@ -59,10 +58,13 @@ where
codec: TCodec,
/// Queue of events to emit in `poll()`.
pending_events: VecDeque<Event<TCodec>>,
/// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`.
pending_outbound: VecDeque<OutboundMessage<TCodec>>,

requested_outbound: VecDeque<OutboundMessage<TCodec>>,
pending_streams: InflightProtocolDataQueue<
(OutboundRequestId, TCodec::Request),
SmallVec<[TCodec::Protocol; 2]>,
Result<(libp2p_swarm::Stream, TCodec::Protocol), StreamUpgradeError<Void>>,
>,

/// A channel for receiving inbound requests.
inbound_receiver: mpsc::Receiver<(
InboundRequestId,
Expand Down Expand Up @@ -102,8 +104,7 @@ where
Self {
inbound_protocols,
codec,
pending_outbound: VecDeque::new(),
requested_outbound: Default::default(),
pending_streams: InflightProtocolDataQueue::default(),
inbound_receiver,
inbound_sender,
pending_events: VecDeque::new(),
Expand Down Expand Up @@ -167,92 +168,6 @@ where
tracing::warn!("Dropping inbound stream because we are at capacity")
}
}

fn on_fully_negotiated_outbound(
&mut self,
FullyNegotiatedOutbound {
protocol: (mut stream, protocol),
info: (),
}: FullyNegotiatedOutbound<
<Self as ConnectionHandler>::OutboundProtocol,
<Self as ConnectionHandler>::OutboundOpenInfo,
>,
) {
let message = self
.requested_outbound
.pop_front()
.expect("negotiated a stream without a pending message");

let mut codec = self.codec.clone();
let request_id = message.request_id;

let send = async move {
let write = codec.write_request(&protocol, &mut stream, message.request);
write.await?;
stream.close().await?;
let read = codec.read_response(&protocol, &mut stream);
let response = read.await?;

Ok(Event::Response {
request_id,
response,
})
};

if self
.worker_streams
.try_push(RequestId::Outbound(request_id), send.boxed())
.is_err()
{
tracing::warn!("Dropping outbound stream because we are at capacity")
}
}

fn on_dial_upgrade_error(
&mut self,
DialUpgradeError { error, info: () }: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol,
>,
) {
let message = self
.requested_outbound
.pop_front()
.expect("negotiated a stream without a pending message");

match error {
StreamUpgradeError::Timeout => {
self.pending_events
.push_back(Event::OutboundTimeout(message.request_id));
}
StreamUpgradeError::NegotiationFailed => {
// The remote merely doesn't support the protocol(s) we requested.
// This is no reason to close the connection, which may
// successfully communicate with other protocols already.
// An event is reported to permit user code to react to the fact that
// the remote peer does not support the requested protocol(s).
self.pending_events
.push_back(Event::OutboundUnsupportedProtocols(message.request_id));
}
StreamUpgradeError::Apply(e) => void::unreachable(e),
StreamUpgradeError::Io(e) => {
tracing::debug!(
"outbound stream for request {} failed: {e}, retrying",
message.request_id
);
self.requested_outbound.push_back(message);
}
}
}
fn on_listen_upgrade_error(
&mut self,
ListenUpgradeError { error, .. }: ListenUpgradeError<
<Self as ConnectionHandler>::InboundOpenInfo,
<Self as ConnectionHandler>::InboundProtocol,
>,
) {
void::unreachable(error)
}
}

/// The events emitted by the [`Handler`].
Expand Down Expand Up @@ -382,82 +297,129 @@ where
}

fn on_behaviour_event(&mut self, request: Self::FromBehaviour) {
self.pending_outbound.push_back(request);
let OutboundMessage {
request_id,
request,
protocols,
} = request;

self.pending_streams
.enqueue_request(protocols, (request_id, request));
}

#[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<ConnectionHandlerEvent<Protocol<TCodec::Protocol>, (), Self::ToBehaviour>> {
match self.worker_streams.poll_unpin(cx) {
Poll::Ready((_, Ok(Ok(event)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::InboundStreamFailed {
request_id: id,
error: e,
},
));
}
Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundStreamFailed {
request_id: id,
error: e,
},
));
}
Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::InboundTimeout(id),
));
}
Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundTimeout(id),
));
loop {
match self.worker_streams.poll_unpin(cx) {
Poll::Ready((_, Ok(Ok(event)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::InboundStreamFailed {
request_id: id,
error: e,
},
));
}
Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundStreamFailed {
request_id: id,
error: e,
},
));
}
Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::InboundTimeout(id),
));
}
Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundTimeout(id),
));
}
Poll::Pending => {}
}
Poll::Pending => {}
}

// Drain pending events that were produced by `worker_streams`.
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
} else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
self.pending_events.shrink_to_fit();
}

// Check for inbound requests.
if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) {
// We received an inbound request.

return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
request_id: id,
request: rq,
sender: rs_sender,
}));
}
// Drain pending events that were produced by `worker_streams`.
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
} else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
self.pending_events.shrink_to_fit();
}

// Emit outbound requests.
if let Some(request) = self.pending_outbound.pop_front() {
let protocols = request.protocols.clone();
self.requested_outbound.push_back(request);
// Check for inbound requests.
if let Poll::Ready(Some((id, rq, rs_sender))) =
self.inbound_receiver.poll_next_unpin(cx)
{
// We received an inbound request.

return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
request_id: id,
request: rq,
sender: rs_sender,
}));
}

return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(Protocol { protocols }, ()),
});
}
match self.pending_streams.next_completed() {
Some((Ok((mut stream, protocol)), (request_id, request))) => {
let mut codec = self.codec.clone();

let send = async move {
let write = codec.write_request(&protocol, &mut stream, request);
write.await?;
stream.close().await?;
let read = codec.read_response(&protocol, &mut stream);
let response = read.await?;

Ok(Event::Response {
request_id,
response,
})
};

if self
.worker_streams
.try_push(RequestId::Outbound(request_id), send.boxed())
.is_err()
{
tracing::warn!("Dropping outbound stream because we are at capacity")
}
continue;
}
Some((Err(StreamUpgradeError::Timeout), (request_id, _))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundTimeout(request_id),
));
}
Some((Err(StreamUpgradeError::NegotiationFailed), (request_id, _))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundUnsupportedProtocols(request_id),
));
}
Some((Err(StreamUpgradeError::Io(error)), (request_id, _))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundStreamFailed { request_id, error },
));
}
Some((Err(StreamUpgradeError::Apply(void)), _)) => void::unreachable(void),
None => {}
}

debug_assert!(self.pending_outbound.is_empty());
// Emit outbound requests.
if let Some(protocols) = self.pending_streams.next_request() {
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(Protocol { protocols }, ()),
});
}

if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
self.pending_outbound.shrink_to_fit();
return Poll::Pending;
}

Poll::Pending
}

fn on_connection_event(
Expand All @@ -473,15 +435,13 @@ where
ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
self.on_fully_negotiated_inbound(fully_negotiated_inbound)
}
ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
self.on_fully_negotiated_outbound(fully_negotiated_outbound)
}
ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
self.on_dial_upgrade_error(dial_upgrade_error)
ConnectionEvent::FullyNegotiatedOutbound(ev) => {
self.pending_streams.submit_response(Ok(ev.protocol));
}
ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
self.on_listen_upgrade_error(listen_upgrade_error)
ConnectionEvent::DialUpgradeError(ev) => {
self.pending_streams.submit_response(Err(ev.error));
}
ConnectionEvent::ListenUpgradeError(ev) => void::unreachable(ev.error),
_ => {}
}
}
Expand Down

0 comments on commit f600aa1

Please sign in to comment.