Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ping pong in the WebSocket transport #646

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Expand Up @@ -77,8 +77,8 @@ http-tls = ["http", "reqwest/default-tls"]
http-native-tls = ["http", "reqwest/native-tls"]
http-rustls-tls = ["http", "reqwest/rustls-tls"]
signing = ["secp256k1", "once_cell"]
ws-tokio = ["soketto", "url", "tokio", "tokio-util", "headers"]
ws-async-std = ["soketto", "url", "async-std", "headers"]
ws-tokio = ["soketto", "url", "tokio", "tokio-util", "headers", "tokio-stream"]
ws-async-std = ["soketto", "url", "async-std", "headers", "tokio-stream"]
ws-tls-tokio = ["async-native-tls", "async-native-tls/runtime-tokio", "ws-tokio"]
ws-tls-async-std = ["async-native-tls", "async-native-tls/runtime-async-std", "ws-async-std"]
ipc-tokio = ["tokio", "tokio-stream", "tokio-util"]
Expand Down
14 changes: 13 additions & 1 deletion src/helpers.rs
Expand Up @@ -6,7 +6,7 @@ use futures::{
Future,
};
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use serde::{de::DeserializeOwned, Deserialize};
use std::{marker::PhantomData, pin::Pin};

/// Takes any type which is deserializable from rpc::Value and such a value and
Expand Down Expand Up @@ -87,6 +87,18 @@ where
}
}

/// Extract the response id from slice. Used to obtain the response id if the deserialization of the whole response fails,
/// workraround for https://github.com/tomusdrw/rust-web3/issues/566
pub fn response_id_from_slice(response: &[u8]) -> Option<rpc::Id> {
#[derive(Deserialize)]
struct JustId {
id: rpc::Id,
}

let value: JustId = serde_json::from_slice(response).ok()?;
Some(value.id)
}

/// Parse bytes slice into JSON-RPC notification.
pub fn to_notification_from_slice(notification: &[u8]) -> error::Result<rpc::Notification> {
serde_json::from_slice(notification).map_err(|e| error::Error::InvalidResponse(format!("{:?}", e)))
Expand Down
148 changes: 111 additions & 37 deletions src/transports/ws.rs
Expand Up @@ -8,6 +8,7 @@ use crate::{
};
use futures::{
channel::{mpsc, oneshot},
future,
task::{Context, Poll},
AsyncRead, AsyncWrite, Future, FutureExt, Stream, StreamExt,
};
Expand All @@ -17,10 +18,12 @@ use soketto::{
};
use std::{
collections::BTreeMap,
convert::TryInto,
fmt,
marker::Unpin,
pin::Pin,
sync::{atomic, Arc},
time::{Duration, Instant},
};
use url::Url;

Expand All @@ -41,6 +44,8 @@ type BatchResult = error::Result<Vec<SingleResult>>;
type Pending = oneshot::Sender<BatchResult>;
type Subscription = mpsc::UnboundedSender<rpc::Value>;

const PING_PONG_INTERVAL: Duration = Duration::from_secs(20);

/// Stream, either plain TCP or TLS.
enum MaybeTlsStream<P, T> {
/// Unencrypted socket stream.
Expand Down Expand Up @@ -95,6 +100,27 @@ struct WsServerTask {
subscriptions: BTreeMap<SubscriptionId, Subscription>,
sender: connection::Sender<MaybeTlsStream<TcpStream, TlsStream>>,
receiver: connection::Receiver<MaybeTlsStream<TcpStream, TlsStream>>,
ping_pong_interval: Option<Duration>,
}

#[cfg(target_arch = "wasm32")]
fn interval_stream(interval: Option<Duration>) -> impl Stream<Item = Instant> {
if interval.is_some() {
log::warn!("Ignoring the ping pong interval, feature unsupported on wasm32");
}
future::pending().into_stream()
}

#[cfg(not(target_arch = "wasm32"))]
fn interval_stream(interval: Option<Duration>) -> impl Stream<Item = Instant> {
use tokio::time;
use tokio_stream::wrappers::IntervalStream;
if let Some(interval) = interval {
let interval = time::interval(interval);
IntervalStream::new(interval).map(|instant| instant.into_std()).boxed()
} else {
future::pending().into_stream().boxed()
}
}

impl WsServerTask {
Expand Down Expand Up @@ -187,6 +213,7 @@ impl WsServerTask {
subscriptions: Default::default(),
sender,
receiver,
ping_pong_interval: PING_PONG_INTERVAL.into(),
})
}

Expand All @@ -196,8 +223,11 @@ impl WsServerTask {
mut sender,
mut pending,
mut subscriptions,
ping_pong_interval,
} = self;

let mut ping_pong_interval = interval_stream(ping_pong_interval);

let receiver = as_data_stream(receiver).fuse();
let requests = requests.fuse();
pin_mut!(receiver);
Expand Down Expand Up @@ -239,6 +269,13 @@ impl WsServerTask {
},
None => break,
},
_ = ping_pong_interval.next().fuse() => {
log::trace!("Pinging the WS connection");
let data = [].as_slice().try_into().unwrap();
if let Err(e) = sender.send_ping(data).await {
log::error!("Sending ping failed: {}", e);
}
}
complete => break,
}
}
Expand All @@ -257,56 +294,93 @@ fn as_data_stream<T: Unpin + futures::AsyncRead + futures::AsyncWrite>(
})
}

enum ParsedMessage {
/// Represents a JSON-RPC notification
Notification(rpc::Notification),
/// Represents a valid JSON-RPC response
Response(rpc::Response),
/// Represents an invalid JSON-RPC response
InvalidResponse(rpc::Id),
}

fn parse_message(data: &[u8]) -> Option<ParsedMessage> {
if let Ok(notification) = helpers::to_notification_from_slice(data) {
Some(ParsedMessage::Notification(notification))
} else if let Ok(response) = helpers::to_response_from_slice(data) {
Some(ParsedMessage::Response(response))
} else if let Some(id) = helpers::response_id_from_slice(data) {
Some(ParsedMessage::InvalidResponse(id))
} else {
None
}
}

fn respond(id: rpc::Id, outputs: Result<Vec<rpc::Output>, Error>, pending: &mut BTreeMap<RequestId, Pending>) {
if let rpc::Id::Num(num) = id {
if let Some(request) = pending.remove(&(num as usize)) {
log::trace!("Responding to (id: {:?}) with {:?}", num, outputs);
let response = outputs.and_then(helpers::to_results_from_outputs);
if let Err(err) = request.send(response) {
log::warn!("Sending a response to deallocated channel: {:?}", err);
}
} else {
log::warn!("Got response for unknown request (id: {:?})", num);
}
} else {
log::warn!("Got unsupported response (id: {:?})", id);
}
}

fn handle_message(
data: &[u8],
subscriptions: &BTreeMap<SubscriptionId, Subscription>,
pending: &mut BTreeMap<RequestId, Pending>,
) {
log::trace!("Message received: {:?}", data);
if let Ok(notification) = helpers::to_notification_from_slice(data) {
if let rpc::Params::Map(params) = notification.params {
let id = params.get("subscription");
let result = params.get("result");

if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
let id: SubscriptionId = id.clone().into();
if let Some(stream) = subscriptions.get(&id) {
if let Err(e) = stream.unbounded_send(result.clone()) {
log::error!("Error sending notification: {:?} (id: {:?}", e, id);
log::trace!("Message received: {:?}", String::from_utf8_lossy(data));
match parse_message(data) {
Some(ParsedMessage::Notification(notification)) => {
if let rpc::Params::Map(params) = notification.params {
let id = params.get("subscription");
let result = params.get("result");
log::debug!("params={:#?}", params);

if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
let id: SubscriptionId = id.clone().into();
log::debug!("subscriptions={:#?}", subscriptions);

if let Some(stream) = subscriptions.get(&id) {
if let Err(e) = stream.unbounded_send(result.clone()) {
log::error!("Error sending notification: {:?} (id: {:?}", e, id);
}
} else {
log::warn!("Got notification for unknown subscription (id: {:?})", id);
}
} else {
log::warn!("Got notification for unknown subscription (id: {:?})", id);
log::error!("Got unsupported notification (id: {:?})", id);
}
} else {
log::error!("Got unsupported notification (id: {:?})", id);
}
}
} else {
let response = helpers::to_response_from_slice(data);
let outputs = match response {
Ok(rpc::Response::Single(output)) => vec![output],
Ok(rpc::Response::Batch(outputs)) => outputs,
_ => vec![],
};
Some(ParsedMessage::Response(response)) => {
let outputs = match response {
rpc::Response::Single(output) => vec![output],
rpc::Response::Batch(outputs) => outputs,
};

let id = match outputs.get(0) {
Some(&rpc::Output::Success(ref success)) => success.id.clone(),
Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(),
None => rpc::Id::Num(0),
};
let id = match outputs.get(0).unwrap() {
&rpc::Output::Success(ref success) => success.id.clone(),
&rpc::Output::Failure(ref failure) => failure.id.clone(),
};

if let rpc::Id::Num(num) = id {
if let Some(request) = pending.remove(&(num as usize)) {
log::trace!("Responding to (id: {:?}) with {:?}", num, outputs);
if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) {
log::warn!("Sending a response to deallocated channel: {:?}", err);
}
} else {
log::warn!("Got response for unknown request (id: {:?})", num);
}
} else {
log::warn!("Got unsupported response (id: {:?})", id);
respond(id, Ok(outputs), pending);
}
Some(ParsedMessage::InvalidResponse(id)) => {
let error = Error::Decoder(String::from_utf8_lossy(data).to_string());
respond(id, Err(error), pending);
}
None => log::warn!(
"Got invalid response, which could not be parsed: {}",
String::from_utf8_lossy(data)
),
}
}

Expand Down