Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into na-subscription-des…
Browse files Browse the repository at this point in the history
…ign2
  • Loading branch information
dvdplm committed May 18, 2021
2 parents 8605f63 + 8780fce commit 531e805
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 13 deletions.
2 changes: 2 additions & 0 deletions ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pin-project = "1"
thiserror = "1"
url = "2"
webpki = { version = "0.22", features = ["std"] }
rustls = "0.19.1"
rustls-native-certs = "0.5.0"

[dev-dependencies]
env_logger = "0.8"
Expand Down
22 changes: 18 additions & 4 deletions ws-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,20 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::helpers::{
build_unsubscribe_message, process_batch_response, process_error_response, process_notification,
process_single_response, process_subscription_response, stop_subscription,
};
use crate::traits::{Client, SubscriptionClient};
use crate::transport::{parse_url, Receiver as WsReceiver, Sender as WsSender, WsTransportClientBuilder};
use crate::v2::error::JsonRpcErrorAlloc;
use crate::v2::params::{Id, JsonRpcParams};
use crate::v2::request::{JsonRpcCallSer, JsonRpcNotificationSer};
use crate::v2::response::{JsonRpcNotifResponse, JsonRpcResponse, JsonRpcSubscriptionResponse};
use crate::TEN_MB_SIZE_BYTES;
use crate::{
helpers::{
build_unsubscribe_message, process_batch_response, process_error_response, process_notification,
process_single_response, process_subscription_response, stop_subscription,
},
transport::CertificateStore,
};
use crate::{
manager::RequestManager, BatchMessage, Error, FrontToBack, NotificationHandler, RegisterNotificationMessage,
RequestMessage, Subscription, SubscriptionMessage,
Expand All @@ -47,6 +50,7 @@ use futures::{
prelude::*,
sink::SinkExt,
};

use serde::de::DeserializeOwned;
use std::{
borrow::Cow,
Expand Down Expand Up @@ -167,6 +171,7 @@ impl RequestIdGuard {
/// Configuration.
#[derive(Clone, Debug)]
pub struct WsClientBuilder<'a> {
certificate_store: CertificateStore,
max_request_body_size: u32,
request_timeout: Option<Duration>,
connection_timeout: Duration,
Expand All @@ -179,6 +184,7 @@ pub struct WsClientBuilder<'a> {
impl<'a> Default for WsClientBuilder<'a> {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
max_request_body_size: TEN_MB_SIZE_BYTES,
request_timeout: None,
connection_timeout: Duration::from_secs(10),
Expand All @@ -191,6 +197,12 @@ impl<'a> Default for WsClientBuilder<'a> {
}

impl<'a> WsClientBuilder<'a> {
/// Set wheather to use system certificates
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
self
}

/// Set max request body size.
pub fn max_request_body_size(mut self, size: u32) -> Self {
self.max_request_body_size = size;
Expand Down Expand Up @@ -248,6 +260,7 @@ impl<'a> WsClientBuilder<'a> {
///
/// `wss://host` - port 443 is used
pub async fn build(self, url: &'a str) -> Result<WsClient, Error> {
let certificate_store = self.certificate_store;
let max_capacity_per_subscription = self.max_notifs_per_subscription;
let max_concurrent_requests = self.max_concurrent_requests;
let request_timeout = self.request_timeout;
Expand All @@ -257,6 +270,7 @@ impl<'a> WsClientBuilder<'a> {
let (sockaddrs, host, mode) = parse_url(url).map_err(|e| Error::Transport(Box::new(e)))?;

let builder = WsTransportClientBuilder {
certificate_store,
sockaddrs,
mode,
host,
Expand Down
4 changes: 2 additions & 2 deletions ws-client/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ impl RequestManager {

/// Removes a notification handler
pub fn remove_notification_handler(&mut self, method: String) -> Result<(), Error> {
if let Some(_) = self.notification_handlers.remove(&method) {
if self.notification_handlers.remove(&method).is_some() {
Ok(())
} else {
Err(Error::UnregisteredNotification(method.to_owned()))
Err(Error::UnregisteredNotification(method))
}
}

Expand Down
45 changes: 38 additions & 7 deletions ws-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub struct Receiver {
/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
#[derive(Debug)]
pub struct WsTransportClientBuilder<'a> {
/// What certificate store to use
pub certificate_store: CertificateStore,
/// Socket addresses to try to connect to.
pub sockaddrs: Vec<SocketAddr>,
/// Host.
Expand All @@ -87,6 +89,16 @@ pub enum Mode {
Tls,
}

/// What certificate store to use
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum CertificateStore {
/// Use the native system certificate store
Native,
/// Use webPki's certificate store
WebPki,
}

/// Error that can happen during the initial handshake.
#[derive(Debug, Error)]
pub enum WsNewError {
Expand Down Expand Up @@ -117,6 +129,10 @@ pub enum WsNewError {
/// Error that can happen during the initial handshake.
#[derive(Debug, Error)]
pub enum WsHandshakeError {
/// Failed to load system certs
#[error("Failed to load system certs: {}", 0)]
CertificateStore(io::Error),

/// Invalid URL.
#[error("Invalid url: {}", 0)]
Url(Cow<'static, str>),
Expand Down Expand Up @@ -199,10 +215,22 @@ impl<'a> WsTransportClientBuilder<'a> {
self
}

/// Try establish the connection.
/// Try to establish the connection.
pub async fn build(self) -> Result<(Sender, Receiver), WsHandshakeError> {
let connector = match self.mode {
Mode::Tls => {
let mut client_config = rustls::ClientConfig::default();
if let CertificateStore::Native = self.certificate_store {
client_config.root_store = rustls_native_certs::load_native_certs()
.map_err(|(_, e)| WsHandshakeError::CertificateStore(e))?;
}
Some(client_config.into())
}
Mode::Plain => None,
};

for sockaddr in &self.sockaddrs {
match self.try_connect(*sockaddr).await {
match self.try_connect(*sockaddr, &connector).await {
Ok(res) => return Ok(res),
Err(e) => {
log::debug!("Failed to connect to sockaddr: {:?} with err: {:?}", sockaddr, e);
Expand All @@ -212,7 +240,11 @@ impl<'a> WsTransportClientBuilder<'a> {
Err(WsHandshakeError::NoAddressFound)
}

async fn try_connect(&self, sockaddr: SocketAddr) -> Result<(Sender, Receiver), WsNewError> {
async fn try_connect(
&self,
sockaddr: SocketAddr,
tls_connector: &Option<async_tls::TlsConnector>,
) -> Result<(Sender, Receiver), WsNewError> {
// Try establish the TCP connection.
let tcp_stream = {
let socket = TcpStream::connect(sockaddr);
Expand All @@ -224,10 +256,9 @@ impl<'a> WsTransportClientBuilder<'a> {
if let Err(err) = socket.set_nodelay(true) {
log::warn!("set nodelay failed: {:?}", err);
}
match self.mode {
Mode::Plain => TlsOrPlain::Plain(socket),
Mode::Tls => {
let connector = async_tls::TlsConnector::default();
match tls_connector {
None => TlsOrPlain::Plain(socket),
Some(connector) => {
let dns_name: &str = webpki::DnsNameRef::try_from_ascii_str(self.host.as_str())?.into();
let tls_stream = connector.connect(dns_name, socket).await?;
TlsOrPlain::Tls(tls_stream)
Expand Down
2 changes: 2 additions & 0 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ async fn background_task(
// Send results back to the client.
tokio::spawn(async move {
while let Some(response) = rx.next().await {
log::debug!("send: {}", response);
let _ = sender.send_binary_mut(response.into_bytes()).await;
let _ = sender.flush().await;
}
Expand Down Expand Up @@ -290,6 +291,7 @@ async fn background_task(
// worst case – unparseable input – we make three calls to [`serde_json::from_slice`] which is pretty annoying.
// Our [issue](https://github.com/paritytech/jsonrpsee/issues/296).
if let Ok(req) = serde_json::from_slice::<JsonRpcRequest>(&data) {
log::debug!("recv: {:?}", req);
execute(&tx, req);
} else if let Ok(batch) = serde_json::from_slice::<Vec<JsonRpcRequest>>(&data) {
if !batch.is_empty() {
Expand Down

0 comments on commit 531e805

Please sign in to comment.