From 837c7789610455b93035caddf0f1691b5f955287 Mon Sep 17 00:00:00 2001 From: Lucas Meurer Date: Sun, 7 Jan 2024 19:14:45 +0100 Subject: [PATCH] feat(tonic): Add client feature flag to support connect_with_connector in wasm32 targets This adds a new feature that enables compilation of connect_with_connector and the Endpoint and Channel struct for wasm. Related: #491 --- tonic/Cargo.toml | 18 +- tonic/src/lib.rs | 2 +- tonic/src/transport/channel/endpoint.rs | 4 +- tonic/src/transport/channel/mod.rs | 9 +- tonic/src/transport/mod.rs | 3 + tonic/src/transport/service/connection.rs | 67 +++++-- tonic/src/transport/service/executor.rs | 24 ++- tonic/src/transport/service/io.rs | 203 ++++++++++++---------- tonic/src/transport/service/mod.rs | 6 + 9 files changed, 216 insertions(+), 120 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 013cc6e72..b9c323e5f 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -38,9 +38,20 @@ transport = [ "channel", "dep:h2", "dep:hyper", + "hyper/full", + "dep:tokio", "tokio/net", "tokio/time", "dep:tower", + "tower/balance", + "dep:hyper-timeout", +] +client = [ + "dep:h2", + "hyper/client", + "hyper/http2", + "dep:tokio", + "dep:tower", "dep:hyper-timeout", ] channel = [] @@ -55,7 +66,7 @@ bytes = "1.0" http = "0.2" tracing = "0.1" -tokio = "1.0.1" +tokio = { version = "1.0.1", default-features = false, optional = true } http-body = "0.4.4" percent-encoding = "2.1" pin-project = "1.0.11" @@ -70,7 +81,7 @@ async-trait = {version = "0.1.13", optional = true} # transport h2 = {version = "0.3.17", optional = true} -hyper = {version = "0.14.26", features = ["full"], optional = true} +hyper = {version = "0.14.26", default-features = false, optional = true} hyper-timeout = {version = "0.4", optional = true} tokio-stream = "0.1" tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} @@ -88,6 +99,9 @@ webpki-roots = { version = "0.25.0", optional = true } flate2 = {version = "1.0", optional = true} zstd = { version = "0.12.3", optional = true } +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen-futures = "0.4.38" + [dev-dependencies] bencher = "0.1.5" quickcheck = "1.0" diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 8aa80a121..719f967ab 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -101,7 +101,7 @@ pub mod metadata; pub mod server; pub mod service; -#[cfg(feature = "transport")] +#[cfg(any(feature = "transport", feature = "client"))] #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub mod transport; diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 6aacb57a5..86e5af988 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -312,6 +312,7 @@ impl Endpoint { } /// Create a channel from this config. + #[cfg(feature = "transport")] pub async fn connect(&self) -> Result { let mut http = hyper::client::connect::HttpConnector::new(); http.enforce_http(false); @@ -333,6 +334,7 @@ impl Endpoint { /// /// The channel returned by this method does not attempt to connect to the endpoint until first /// use. + #[cfg(feature = "transport")] pub fn connect_lazy(&self) -> Channel { let mut http = hyper::client::connect::HttpConnector::new(); http.enforce_http(false); @@ -428,7 +430,7 @@ impl From for Endpoint { http2_keep_alive_while_idle: None, connect_timeout: None, http2_adaptive_window: None, - executor: SharedExec::tokio(), + executor: SharedExec::default_exec(), } } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..e33d58405 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -9,7 +9,9 @@ pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; -use super::service::{Connection, DynamicServiceStream, SharedExec}; +use super::service::{Connection}; +#[cfg(feature = "transport")] +use super::service::{DynamicServiceStream, SharedExec}; use crate::body::BoxBody; use crate::transport::Executor; use bytes::Bytes; @@ -109,6 +111,7 @@ impl Channel { /// /// This creates a [`Channel`] that will load balance across all the /// provided endpoints. + #[cfg(feature = "transport")] pub fn balance_list(list: impl Iterator) -> Self { let (channel, tx) = Self::balance_channel(DEFAULT_BUFFER_SIZE); list.for_each(|endpoint| { @@ -122,11 +125,12 @@ impl Channel { /// Balance a list of [`Endpoint`]'s. /// /// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints. + #[cfg(feature = "transport")] pub fn balance_channel(capacity: usize) -> (Self, Sender>) where K: Hash + Eq + Send + Clone + 'static, { - Self::balance_channel_with_executor(capacity, SharedExec::tokio()) + Self::balance_channel_with_executor(capacity, SharedExec::default_exec()) } /// Balance a list of [`Endpoint`]'s. @@ -134,6 +138,7 @@ impl Channel { /// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints. /// /// The [`Channel`] will use the given executor to spawn async tasks. + #[cfg(feature = "transport")] pub fn balance_channel_with_executor( capacity: usize, executor: E, diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index a0435c797..ff349cac2 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -88,6 +88,7 @@ //! [rustls]: https://docs.rs/rustls/0.16.0/rustls/ pub mod channel; +#[cfg(feature = "transport")] pub mod server; mod error; @@ -100,6 +101,7 @@ mod tls; #[cfg_attr(docsrs, doc(cfg(feature = "channel")))] pub use self::channel::{Channel, Endpoint}; pub use self::error::Error; +#[cfg(feature = "transport")] #[doc(inline)] pub use self::server::Server; #[doc(inline)] @@ -107,6 +109,7 @@ pub use self::service::grpc_timeout::TimeoutExpired; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; +#[cfg(feature = "transport")] pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; pub use hyper::{Body, Uri}; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 46a88dda5..6bd4b669b 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,26 +1,29 @@ -use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; -use crate::{ - body::BoxBody, - transport::{BoxFuture, Endpoint}, +use std::{ + fmt, + task::{Context, Poll}, }; + use http::Uri; use hyper::client::conn::Builder; use hyper::client::connect::Connection as HyperConnection; use hyper::client::service::Connect as HyperConnect; -use std::{ - fmt, - task::{Context, Poll}, -}; use tokio::io::{AsyncRead, AsyncWrite}; -use tower::load::Load; use tower::{ layer::Layer, limit::{concurrency::ConcurrencyLimitLayer, rate::RateLimitLayer}, - util::BoxService, - ServiceBuilder, ServiceExt, + ServiceBuilder, + ServiceExt, util::BoxService, }; +use tower::load::Load; use tower_service::Service; +use crate::{ + body::BoxBody, + transport::{BoxFuture, Endpoint}, +}; + +use super::{AddOrigin, grpc_timeout::GrpcTimeout, reconnect::Reconnect, UserAgent}; + pub(crate) type Request = http::Request; pub(crate) type Response = http::Response; @@ -40,20 +43,32 @@ impl Connection { .http2_initial_stream_window_size(endpoint.init_stream_window_size) .http2_initial_connection_window_size(endpoint.init_connection_window_size) .http2_only(true) - .http2_keep_alive_interval(endpoint.http2_keep_alive_interval) .executor(endpoint.executor.clone()) .clone(); - if let Some(val) = endpoint.http2_keep_alive_timeout { - settings.http2_keep_alive_timeout(val); + if let Some(val) = endpoint.http2_adaptive_window { + settings.http2_adaptive_window(val); } - if let Some(val) = endpoint.http2_keep_alive_while_idle { - settings.http2_keep_alive_while_idle(val); + #[cfg(feature = "transport")] + { + settings + .http2_keep_alive_interval(endpoint.http2_keep_alive_interval); + + if let Some(val) = endpoint.http2_keep_alive_timeout { + settings.http2_keep_alive_timeout(val); + } + + if let Some(val) = endpoint.http2_keep_alive_while_idle { + settings.http2_keep_alive_while_idle(val); + } } - if let Some(val) = endpoint.http2_adaptive_window { - settings.http2_adaptive_window(val); + #[cfg(target_arch = "wasm32")] + { + settings.executor(wasm::Executor) + // reset streams require `Instant::now` which is not available on wasm + .http2_max_concurrent_reset_streams(0); } let stack = ServiceBuilder::new() @@ -126,3 +141,19 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +#[cfg(target_arch = "wasm32")] +mod wasm { + use std::future::Future; + use std::pin::Pin; + + type BoxSendFuture = Pin + Send>>; + + pub(crate) struct Executor; + + impl hyper::rt::Executor for Executor { + fn execute(&self, fut: BoxSendFuture) { + wasm_bindgen_futures::spawn_local(fut) + } + } +} \ No newline at end of file diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/service/executor.rs index de3cfbe6e..1cce1d36a 100644 --- a/tonic/src/transport/service/executor.rs +++ b/tonic/src/transport/service/executor.rs @@ -3,9 +3,11 @@ use std::{future::Future, sync::Arc}; pub(crate) use hyper::rt::Executor; +#[cfg(not(target_arch = "wasm32"))] #[derive(Copy, Clone)] struct TokioExec; +#[cfg(not(target_arch = "wasm32"))] impl Executor for TokioExec where F: Future + Send + 'static, @@ -16,6 +18,21 @@ where } } +#[cfg(target_arch = "wasm32")] +#[derive(Copy, Clone)] +struct WasmBindgenExec; + +#[cfg(target_arch = "wasm32")] +impl Executor for WasmBindgenExec +where + F: Future + 'static, + F::Output: 'static, +{ + fn execute(&self, fut: F) { + wasm_bindgen_futures::spawn_local(async move {fut.await;}); + } +} + #[derive(Clone)] pub(crate) struct SharedExec { inner: Arc> + Send + Sync + 'static>, @@ -31,8 +48,11 @@ impl SharedExec { } } - pub(crate) fn tokio() -> Self { - Self::new(TokioExec) + pub(crate) fn default_exec() -> Self { + #[cfg(not(target_arch = "wasm32"))] + return Self::new(TokioExec); + #[cfg(target_arch = "wasm32")] + Self::new(WasmBindgenExec) } } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2230b9b2e..360979721 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,12 +1,13 @@ -use crate::transport::server::Connected; use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -#[cfg(feature = "tls")] -use tokio_rustls::server::TlsStream; + +#[cfg(feature = "transport")] +pub(crate) use server::ServerIo; pub(in crate::transport) trait Io: AsyncRead + AsyncWrite + Send + 'static @@ -29,7 +30,8 @@ impl Connection for BoxedIo { } } -impl Connected for BoxedIo { +#[cfg(feature = "transport")] +impl crate::transport::server::Connected for BoxedIo { type ConnectInfo = NoneConnectInfo; fn connect_info(&self) -> Self::ConnectInfo { @@ -80,120 +82,133 @@ impl AsyncWrite for BoxedIo { } } -pub(crate) enum ServerIo { - Io(IO), +#[cfg(feature = "transport")] +mod server { + use tower::util::Either; + use crate::transport::server::Connected; + use std::io; + use std::io::IoSlice; + use std::pin::Pin; + use std::task::{Context, Poll}; + #[cfg(feature = "tls")] - TlsIo(Box>), -} + use tokio_rustls::server::TlsStream; + + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tower::util::Either; + pub(crate) enum ServerIo { + Io(IO), + #[cfg(feature = "tls")] + TlsIo(Box>), + } -#[cfg(feature = "tls")] -type ServerIoConnectInfo = + #[cfg(feature = "tls")] + type ServerIoConnectInfo = Either<::ConnectInfo, as Connected>::ConnectInfo>; -#[cfg(not(feature = "tls"))] -type ServerIoConnectInfo = Either<::ConnectInfo, ()>; + #[cfg(not(feature = "tls"))] + type ServerIoConnectInfo = Either<::ConnectInfo, ()>; -impl ServerIo { - pub(in crate::transport) fn new_io(io: IO) -> Self { - Self::Io(io) - } + impl ServerIo { + pub(in crate::transport) fn new_io(io: IO) -> Self { + Self::Io(io) + } - #[cfg(feature = "tls")] - pub(in crate::transport) fn new_tls_io(io: TlsStream) -> Self { - Self::TlsIo(Box::new(io)) - } + #[cfg(feature = "tls")] + pub(in crate::transport) fn new_tls_io(io: TlsStream) -> Self { + Self::TlsIo(Box::new(io)) + } - #[cfg(feature = "tls")] - pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo - where - IO: Connected, - TlsStream: Connected, - { - match self { - Self::Io(io) => Either::A(io.connect_info()), - Self::TlsIo(io) => Either::B(io.connect_info()), + #[cfg(feature = "tls")] + pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo + where + IO: Connected, + TlsStream: Connected, + { + match self { + Self::Io(io) => Either::A(io.connect_info()), + Self::TlsIo(io) => Either::B(io.connect_info()), + } } - } - #[cfg(not(feature = "tls"))] - pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo - where - IO: Connected, - { - match self { - Self::Io(io) => Either::A(io.connect_info()), + #[cfg(not(feature = "tls"))] + pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo + where + IO: Connected, + { + match self { + Self::Io(io) => Either::A(io.connect_info()), + } } } -} -impl AsyncRead for ServerIo -where - IO: AsyncWrite + AsyncRead + Unpin, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match &mut *self { - Self::Io(io) => Pin::new(io).poll_read(cx, buf), - #[cfg(feature = "tls")] - Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf), + impl AsyncRead for ServerIo + where + IO: AsyncWrite + AsyncRead + Unpin, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_read(cx, buf), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf), + } } } -} -impl AsyncWrite for ServerIo -where - IO: AsyncWrite + AsyncRead + Unpin, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match &mut *self { - Self::Io(io) => Pin::new(io).poll_write(cx, buf), - #[cfg(feature = "tls")] - Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf), + impl AsyncWrite for ServerIo + where + IO: AsyncWrite + AsyncRead + Unpin, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write(cx, buf), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf), + } } - } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - Self::Io(io) => Pin::new(io).poll_flush(cx), - #[cfg(feature = "tls")] - Self::TlsIo(io) => Pin::new(io).poll_flush(cx), + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_flush(cx), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_flush(cx), + } } - } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - Self::Io(io) => Pin::new(io).poll_shutdown(cx), - #[cfg(feature = "tls")] - Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_shutdown(cx), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), + } } - } - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - match &mut *self { - Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs), - #[cfg(feature = "tls")] - Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs), + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs), + } } - } - fn is_write_vectored(&self) -> bool { - match self { - Self::Io(io) => io.is_write_vectored(), - #[cfg(feature = "tls")] - Self::TlsIo(io) => io.is_write_vectored(), + fn is_write_vectored(&self) -> bool { + match self { + Self::Io(io) => io.is_write_vectored(), + #[cfg(feature = "tls")] + Self::TlsIo(io) => io.is_write_vectored(), + } } } } diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 69d850f10..0deb67db3 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,11 +1,13 @@ mod add_origin; mod connection; mod connector; +#[cfg(feature = "transport")] mod discover; pub(crate) mod executor; pub(crate) mod grpc_timeout; mod io; mod reconnect; +#[cfg(feature = "transport")] mod router; #[cfg(feature = "tls")] mod tls; @@ -14,13 +16,17 @@ mod user_agent; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::Connector; +#[cfg(feature = "transport")] pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::executor::SharedExec; pub(crate) use self::grpc_timeout::GrpcTimeout; +#[cfg(feature = "transport")] pub(crate) use self::io::ServerIo; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; pub(crate) use self::user_agent::UserAgent; +#[cfg(feature = "transport")] pub use self::router::Routes; +#[cfg(feature = "transport")] pub use self::router::RoutesBuilder;