From adbb7756349762089050b9a9fb0aef0cff5cb50e Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:13:48 +0000 Subject: [PATCH 01/25] Upgrade to hyper 1 and http 1 Upgrades only in Cargo.toml Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/Cargo.toml | 16 +++++++++------- interop/Cargo.toml | 6 +++--- tests/compression/Cargo.toml | 8 +++++--- tests/integration_tests/Cargo.toml | 7 ++++--- tonic-web/Cargo.toml | 6 +++--- tonic-web/tests/integration/Cargo.toml | 5 ++++- tonic/Cargo.toml | 17 ++++++++++------- 7 files changed, 38 insertions(+), 27 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 57d05d3e3..a672287d8 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -311,16 +311,18 @@ serde_json = { version = "1.0", optional = true } tracing = { version = "0.1.16", optional = true } tracing-subscriber = { version = "0.3", features = ["tracing-log", "fmt"], optional = true } prost-types = { version = "0.12", optional = true } -http = { version = "0.2", optional = true } -http-body = { version = "0.4.2", optional = true } -hyper = { version = "0.14", optional = true } +http = { version = "1", optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +hyper = { version = "1", optional = true } +hyper-util = { version = "0.1", optional = true } listenfd = { version = "1.0", optional = true } bytes = { version = "1", optional = true } h2 = { version = "0.3", optional = true } -tokio-rustls = { version = "0.24.0", optional = true } -hyper-rustls = { version = "0.24.0", features = ["http2"], optional = true } -rustls-pemfile = { version = "1", optional = true } -tower-http = { version = "0.4", optional = true } +tokio-rustls = { version = "0.26", optional = true, features = ["ring", "tls12"], default-features = false } +hyper-rustls = { version = "0.27.0", features = ["http2", "ring", "tls12"], optional = true, default-features = false } +rustls-pemfile = { version = "2.0.0", optional = true } +tower-http = { version = "0.5", optional = true } [build-dependencies] tonic-build = { path = "../tonic-build", features = ["prost"] } diff --git a/interop/Cargo.toml b/interop/Cargo.toml index a58ef64cf..9a32b2a1d 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -19,9 +19,9 @@ async-stream = "0.3" strum = {version = "0.26", features = ["derive"]} pico-args = {version = "0.5", features = ["eq-separator"]} console = "0.15" -http = "0.2" -http-body = "0.4.2" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" prost = "0.12" tokio = {version = "1.0", features = ["rt-multi-thread", "time", "macros"]} tokio-stream = "0.1" diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 5bc87c829..4ba549cdc 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -8,9 +8,11 @@ version = "0.1.0" [dependencies] bytes = "1" -http = "0.2" -http-body = "0.4" -hyper = "0.14.3" +http = "1" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" paste = "1.0.12" pin-project = "1.0" prost = "0.12" diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 222d1919c..6a7ec8052 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -17,9 +17,10 @@ tracing-subscriber = {version = "0.3"} [dev-dependencies] async-stream = "0.3" -http = "0.2" -http-body = "0.4" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" +hyper-util = "0.1" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} tower-http = { version = "0.4", features = ["set-header", "trace"] } diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index d6649f65c..5813fd6a3 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -18,9 +18,9 @@ version = "0.11.0" base64 = "0.22" bytes = "1" tokio-stream = "0.1" -http = "0.2" -http-body = "0.4" -hyper = {version = "0.14", default-features = false, features = ["stream"]} +http = "1" +http-body = "1" +http-body-util = "0.1" pin-project = "1" tonic = {version = "0.11", path = "../tonic", default-features = false} tower-service = "0.3" diff --git a/tonic-web/tests/integration/Cargo.toml b/tonic-web/tests/integration/Cargo.toml index 5c6d5727e..38fd9ff32 100644 --- a/tonic-web/tests/integration/Cargo.toml +++ b/tonic-web/tests/integration/Cargo.toml @@ -9,7 +9,10 @@ license = "MIT" [dependencies] base64 = "0.22" bytes = "1.0" -hyper = "0.14" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" prost = "0.12" tokio = { version = "1", features = ["macros", "rt", "net"] } tokio-stream = { version = "0.1", features = ["net"] } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 0d2be669f..1934c416e 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -51,10 +51,11 @@ channel = [] [dependencies] base64 = "0.22" bytes = "1.0" -http = "0.2" +http = "1" tracing = "0.1" -http-body = "0.4.4" +http-body = "1" +http-body-util = "0.1" percent-encoding = "2.1" pin-project = "1.0.11" tower-layer = "0.3" @@ -68,11 +69,13 @@ async-trait = {version = "0.1.13", optional = true} # transport async-stream = {version = "0.3", optional = true} -h2 = {version = "0.3.24", optional = true} -hyper = {version = "0.14.26", features = ["full"], optional = true} -hyper-timeout = {version = "0.4", optional = true} -tokio = {version = "1.0.1", optional = true} -tokio-stream = "0.1" +h2 = {version = "0.4", optional = true} +hyper = {version = "1", features = ["full"], optional = true} +hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true } +hyper-timeout = {version = "0.5", optional = true} +socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } +tokio = {version = "1", default-features = false, optional = true} +tokio-stream = { version = "0.1", features = ["net"] } tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} axum = {version = "0.6.9", default-features = false, optional = true} From ddb2e004b9db103ad91b990f6fb39944aaa69855 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:17:49 +0000 Subject: [PATCH 02/25] Convert from hyper::Body to http_body::BoxedBody When appropriate, we replace `hyper::Body` with `http_body::BoxedBody`, a good general purpose replacement for `hyper::Body`. Hyper does provide `hyper::body::Incoming`, but we cannot construct that, so anywhere we might need a body that we can construct (even most Service trait impls) we must use something like `http_body::BoxedBody`. When a service accepts `BoxedBody` and not `Incoming`, this indicates that the service is designed to run in places where it is not adjacent to hyper, for example, after routing (which is managed by Axum) Additionally, http >= 1 requires that extension types are `Clone`, so this bound has been added where appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/src/interceptor/server.rs | 1 + examples/src/tower/client.rs | 3 +- examples/src/tower/server.rs | 10 +++-- interop/src/server.rs | 6 +-- tests/integration_tests/tests/extensions.rs | 9 +++-- tests/integration_tests/tests/origin.rs | 1 + tonic-web/src/layer.rs | 2 +- tonic-web/src/lib.rs | 10 ++--- tonic-web/src/service.rs | 39 +++++++++---------- tonic-web/tests/integration/tests/grpc_web.rs | 10 +++-- tonic/src/body.rs | 6 +-- tonic/src/extensions.rs | 2 +- tonic/src/request.rs | 2 +- tonic/src/transport/server/mod.rs | 11 ++++-- tonic/src/transport/service/connection.rs | 5 +-- 15 files changed, 62 insertions(+), 55 deletions(-) diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs index 263348a6d..fd0cf462f 100644 --- a/examples/src/interceptor/server.rs +++ b/examples/src/interceptor/server.rs @@ -57,6 +57,7 @@ fn intercept(mut req: Request<()>) -> Result, Status> { Ok(req) } +#[derive(Clone)] struct MyExtension { some_piece_of_data: String, } diff --git a/examples/src/tower/client.rs b/examples/src/tower/client.rs index 0a33fffae..39fec5d47 100644 --- a/examples/src/tower/client.rs +++ b/examples/src/tower/client.rs @@ -44,7 +44,6 @@ mod service { use std::pin::Pin; use std::task::{Context, Poll}; use tonic::body::BoxBody; - use tonic::transport::Body; use tonic::transport::Channel; use tower::Service; @@ -59,7 +58,7 @@ mod service { } impl Service> for AuthSvc { - type Response = Response; + type Response = Response; type Error = Box; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; diff --git a/examples/src/tower/server.rs b/examples/src/tower/server.rs index cc85d62e5..b7066a1b6 100644 --- a/examples/src/tower/server.rs +++ b/examples/src/tower/server.rs @@ -1,4 +1,3 @@ -use hyper::Body; use std::{ pin::Pin, task::{Context, Poll}, @@ -84,9 +83,12 @@ struct MyMiddleware { type BoxFuture<'a, T> = Pin + Send + 'a>>; -impl Service> for MyMiddleware +impl Service> for MyMiddleware where - S: Service, Response = hyper::Response> + Clone + Send + 'static, + S: Service, Response = hyper::Response> + + Clone + + Send + + 'static, S::Future: Send + 'static, { type Response = S::Response; @@ -97,7 +99,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: hyper::Request) -> Self::Future { + fn call(&mut self, req: hyper::Request) -> Self::Future { // This is necessary because tonic internally uses `tower::buffer::Buffer`. // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 // for details on why this is necessary diff --git a/interop/src/server.rs b/interop/src/server.rs index b32468866..aef7b0d45 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -180,9 +180,9 @@ impl EchoHeadersSvc { } } -impl Service> for EchoHeadersSvc +impl Service> for EchoHeadersSvc where - S: Service, Response = http::Response> + Send, + S: Service, Response = http::Response> + Send, S::Future: Send + 'static, { type Response = S::Response; @@ -193,7 +193,7 @@ where Ok(()).into() } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned(); let echo_trailer = req diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs index b112f8e66..b2380181d 100644 --- a/tests/integration_tests/tests/extensions.rs +++ b/tests/integration_tests/tests/extensions.rs @@ -1,4 +1,4 @@ -use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; +use hyper::{Request as HyperRequest, Response as HyperResponse}; use integration_tests::{ pb::{test_client, test_server, Input, Output}, BoxFuture, @@ -16,6 +16,7 @@ use tonic::{ }; use tower_service::Service; +#[derive(Clone)] struct ExtensionValue(i32); #[tokio::test] @@ -112,9 +113,9 @@ struct InterceptedService { inner: S, } -impl Service> for InterceptedService +impl Service> for InterceptedService where - S: Service, Response = HyperResponse> + S: Service, Response = HyperResponse> + NamedService + Clone + Send @@ -129,7 +130,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, mut req: HyperRequest) -> Self::Future { + fn call(&mut self, mut req: HyperRequest) -> Self::Future { let clone = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, clone); diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index f149dc68d..c8140c79f 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -7,6 +7,7 @@ use std::time::Duration; use tokio::sync::oneshot; use tonic::codegen::http::Request; use tonic::{ + body::BoxBody, transport::{Endpoint, Server}, Response, Status, }; diff --git a/tonic-web/src/layer.rs b/tonic-web/src/layer.rs index 77b03c77e..7834f1990 100644 --- a/tonic-web/src/layer.rs +++ b/tonic-web/src/layer.rs @@ -24,7 +24,7 @@ impl Default for GrpcWebLayer { impl Layer for GrpcWebLayer where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 16e57e19d..50ed8c0a8 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -127,7 +127,7 @@ type BoxError = Box; /// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice. pub fn enable(service: S) -> CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -159,9 +159,9 @@ where #[derive(Debug, Clone)] pub struct CorsGrpcWeb(tower_http::cors::Cors>); -impl Service> for CorsGrpcWeb +impl Service> for CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -169,7 +169,7 @@ where type Response = S::Response; type Error = S::Error; type Future = - > as Service>>::Future; + > as Service>>::Future; fn poll_ready( &mut self, @@ -178,7 +178,7 @@ where self.0.poll_ready(cx) } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { self.0.call(req) } } diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index af4c5276f..da65ba832 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; -use hyper::Body; +use http_body_util::BodyExt; use pin_project::pin_project; use tonic::{ body::{empty_body, BoxBody}, @@ -50,7 +50,7 @@ impl GrpcWebService { impl GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, { fn response(&self, status: StatusCode) -> ResponseFuture { ResponseFuture { @@ -66,9 +66,9 @@ where } } -impl Service> for GrpcWebService +impl Service> for GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { @@ -80,7 +80,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { match RequestKind::new(req.headers(), req.method(), req.version()) { // A valid grpc-web request, regardless of HTTP version. // @@ -202,7 +202,7 @@ impl<'a> RequestKind<'a> { // Mutating request headers to conform to a gRPC request is not really // necessary for us at this point. We could remove most of these except // maybe for inserting `header::TE`, which tonic should check? -fn coerce_request(mut req: Request, encoding: Encoding) -> Request { +fn coerce_request(mut req: Request, encoding: Encoding) -> Request { req.headers_mut().remove(header::CONTENT_LENGTH); req.headers_mut() @@ -216,8 +216,7 @@ fn coerce_request(mut req: Request, encoding: Encoding) -> Request { HeaderValue::from_static("identity,deflate,gzip"), ); - req.map(|b| GrpcWebCall::request(b, encoding)) - .map(Body::wrap_stream) + req.map(|b| GrpcWebCall::request(b, encoding).boxed_unsync()) } fn coerce_response(res: Response, encoding: Encoding) -> Response { @@ -246,7 +245,7 @@ mod tests { #[derive(Debug, Clone)] struct Svc; - impl tower_service::Service> for Svc { + impl tower_service::Service> for Svc { type Response = Response; type Error = String; type Future = BoxFuture; @@ -255,7 +254,7 @@ mod tests { Poll::Ready(Ok(())) } - fn call(&mut self, _: Request) -> Self::Future { + fn call(&mut self, _: Request) -> Self::Future { Box::pin(async { Ok(Response::new(empty_body())) }) } } @@ -266,15 +265,14 @@ mod tests { mod grpc_web { use super::*; - use http::HeaderValue; use tower_layer::Layer; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::POST) .header(CONTENT_TYPE, GRPC_WEB) .header(ORIGIN, "http://example.com") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -350,13 +348,13 @@ mod tests { mod options { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::OPTIONS) .header(ORIGIN, "http://example.com") .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web") .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -371,13 +369,12 @@ mod tests { mod grpc { use super::*; - use http::HeaderValue; - fn request() -> Request { + fn request() -> Request { Request::builder() .version(Version::HTTP_2) .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -397,7 +394,7 @@ mod tests { let req = Request::builder() .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap(); let res = svc.call(req).await.unwrap(); @@ -425,10 +422,10 @@ mod tests { mod other { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .header(CONTENT_TYPE, "application/text") - .body(Body::empty()) + .body(empty_body()) .unwrap() } diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 3343d754c..037ff8dad 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -2,11 +2,13 @@ use std::net::SocketAddr; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http_body_util::{BodyExt as _, Full}; use hyper::http::{header, StatusCode}; -use hyper::{Body, Client, Method, Request, Uri}; +use hyper::{Method, Request, Uri}; use prost::Message; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; +use tonic::body::BoxBody; use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; @@ -102,7 +104,7 @@ fn encode_body() -> Bytes { buf.split_to(len + 5).freeze() } -fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { +fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { use header::{ACCEPT, CONTENT_TYPE, ORIGIN}; let request_uri = format!("{}/{}/{}", base_uri, "test.Test", "UnaryCall") @@ -123,7 +125,9 @@ fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request< .header(ORIGIN, "http://example.com") .header(ACCEPT, format!("application/{}", accept)) .uri(request_uri) - .body(Body::from(bytes)) + .body(BoxBody::new( + Full::new(bytes).map_err(|err| Status::internal(err.to_string())), + )) .unwrap() } diff --git a/tonic/src/body.rs b/tonic/src/body.rs index ef95eec47..428c0dade 100644 --- a/tonic/src/body.rs +++ b/tonic/src/body.rs @@ -1,9 +1,9 @@ //! HTTP specific body utilities. -use http_body::Body; +use http_body_util::BodyExt; /// A type erased HTTP body used for tonic services. -pub type BoxBody = http_body::combinators::UnsyncBoxBody; +pub type BoxBody = http_body_util::combinators::UnsyncBoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. pub(crate) fn boxed(body: B) -> BoxBody @@ -16,7 +16,7 @@ where /// Create an empty `BoxBody` pub fn empty_body() -> BoxBody { - http_body::Empty::new() + http_body_util::Empty::new() .map_err(|err| match err {}) .boxed_unsync() } diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index 37d84b87b..32b9ad021 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -24,7 +24,7 @@ impl Extensions { /// If a extension of this type already existed, it will /// be returned. #[inline] - pub fn insert(&mut self, val: T) -> Option { + pub fn insert(&mut self, val: T) -> Option { self.inner.insert(val) } diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 76bf4e9eb..e0829424c 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -311,6 +311,7 @@ impl Request { /// ```no_run /// use tonic::{Request, service::interceptor}; /// + /// #[derive(Clone)] // Extensions must be Clone /// struct MyExtension { /// some_piece_of_data: String, /// } @@ -438,7 +439,6 @@ pub(crate) enum SanitizeHeaders { #[cfg(test)] mod tests { use super::*; - use crate::metadata::MetadataValue; use http::Uri; #[test] diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7f2ffde2b..ad930c617 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -35,12 +35,13 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; +use crate::body::boxed; use crate::body::BoxBody; use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; -use http_body::Body as _; -use hyper::{server::accept, Body}; +use http_body_util::BodyExt; +use hyper::server::accept; use pin_project::pin_project; use std::{ convert::Infallible, @@ -63,9 +64,11 @@ use tower::{ Service, ServiceBuilder, }; -type BoxHttpBody = http_body::combinators::UnsyncBoxBody; -type BoxService = tower::util::BoxService, Response, crate::Error>; type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; +type BoxHttpBody = crate::body::BoxBody; +type Body = hyper::body::Incoming; // Temporary type alias to ease transition +type BoxError = crate::Error; +type BoxService = tower::util::BoxCloneService, Response, crate::Error>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 46a88dda5..b3428aa2c 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,6 +1,6 @@ use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{ - body::BoxBody, + body::{boxed, BoxBody}, transport::{BoxFuture, Endpoint}, }; use http::Uri; @@ -21,8 +21,7 @@ use tower::{ }; use tower_service::Service; -pub(crate) type Request = http::Request; -pub(crate) type Response = http::Response; +pub(crate) use crate::transport::{Request, Response}; pub(crate) struct Connection { inner: BoxService, From 1a2af6622fc64efd6981342599171f44e0173484 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:29:37 +0000 Subject: [PATCH 03/25] Convert tonic::codec::decode to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev --- tonic/benches/decode.rs | 21 ++++++------ tonic/src/codec/decode.rs | 71 ++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 49 deletions(-) diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..22ab6d9d4 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -1,6 +1,6 @@ use bencher::{benchmark_group, benchmark_main, Bencher}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use http_body::Body; +use http_body::{Body, Frame, SizeHint}; use std::{ fmt::{Error, Formatter}, pin::Pin, @@ -58,23 +58,24 @@ impl Body for MockBody { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { if self.data.has_remaining() { let split = std::cmp::min(self.chunk_size, self.data.remaining()); - Poll::Ready(Some(Ok(self.data.split_to(split)))) + Poll::Ready(Some(Ok(Frame::data(self.data.split_to(split))))) } else { Poll::Ready(None) } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + fn is_end_stream(&self) -> bool { + !self.data.is_empty() + } + + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.data.len() as u64) } } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 081f6193d..dea83d931 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -2,8 +2,9 @@ use super::compression::{decompress, CompressionEncoding, CompressionSettings}; use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use http_body::Body; +use http_body_util::BodyExt; use std::{ fmt, future, pin::Pin, @@ -27,7 +28,7 @@ struct StreamingInner { state: State, direction: Direction, buf: BytesMut, - trailers: Option, + trailers: Option, decompress_buf: BytesMut, encoding: Option, max_message_size: Option, @@ -121,7 +122,7 @@ impl Streaming { decoder: Box::new(decoder), inner: StreamingInner { body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_frame(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))) .map_err(|err| Status::map_error(err.into())) .boxed_unsync(), state: State::ReadHeader, @@ -239,8 +240,8 @@ impl StreamingInner { } // Returns Some(()) if data was found or None if the loop in `poll_next` should break - fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { + let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { Some(Ok(d)) => Some(d), Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { @@ -254,9 +255,18 @@ impl StreamingInner { None => None, }; - Poll::Ready(if let Some(data) = chunk { - self.buf.put(data); - Ok(Some(())) + Poll::Ready(if let Some(frame) = chunk { + match frame { + frame if frame.is_data() => { + self.buf.put(frame.into_data().unwrap()); + Ok(Some(())) + } + frame if frame.is_trailers() => { + self.trailers = Some(frame.into_trailers().unwrap()); + Ok(None) + } + frame => panic!("unexpected frame: {:?}", frame), + } } else { // FIXME: improve buf usage. if self.buf.has_remaining() { @@ -271,27 +281,18 @@ impl StreamingInner { }) } - fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { + fn response(&mut self) -> Result<(), Status> { if let Direction::Response(status) = self.direction { - match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { - Ok(trailer) => { - if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - if let Some(e) = e { - return Poll::Ready(Err(e)); - } else { - return Poll::Ready(Ok(())); - } - } else { - self.trailers = trailer.map(MetadataMap::from_headers); - } - } - Err(status) => { - debug!("decoder inner trailers error: {:?}", status); - return Poll::Ready(Err(status)); + if let Err(e) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) { + if let Some(e) = e { + // If the trailers contain a grpc-status, then we should return that as the error + // and otherwise stop the stream (by taking the error state) + self.trailers.take(); + return Err(e); } } } - Poll::Ready(Ok(())) + Ok(()) } } @@ -351,7 +352,7 @@ impl Streaming { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } // To fetch the trailers we must clear the body and drop it. @@ -360,16 +361,11 @@ impl Streaming { // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } - // Trailers were not caught during poll_next and thus lets poll for - // them manually. - let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) - .await - .map_err(|e| Status::from_error(Box::new(e))); - - map.map(|x| x.map(MetadataMap::from_headers)) + // We've polled through all the frames, and still no trailers, return None + Ok(None) } fn decode_chunk(&mut self) -> Result, Status> { @@ -395,20 +391,17 @@ impl Stream for Streaming { return Poll::Ready(None); } - // FIXME: implement the ability to poll trailers when we _know_ that - // the consumer of this stream will only poll for the first message. - // This means we skip the poll_trailers step. if let Some(item) = self.decode_chunk()? { return Poll::Ready(Some(Ok(item))); } - match ready!(self.inner.poll_data(cx))? { + match ready!(self.inner.poll_frame(cx))? { Some(()) => (), None => break, } } - Poll::Ready(match ready!(self.inner.poll_response(cx)) { + Poll::Ready(match self.inner.response() { Ok(()) => None, Err(err) => Some(Err(err)), }) From 6593027165e2154f8c1b64ec99a1d34eab4ccbe2 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:35:14 +0000 Subject: [PATCH 04/25] Convert tonic::transport::channel to use http >= 1 body types tonic::transport::channel previously used `hyper::Body` as the response body type. This type no longer exists in hyper >= 1, and so has been converted to a `BoxBody` provided by `http_body_util` designed for interoperability between http crates. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/transport/channel/mod.rs | 6 +++--- tonic/src/transport/service/connection.rs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..6a857dff1 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -38,7 +38,7 @@ use tower::{ Service, }; -type Svc = Either, Response, crate::Error>>; +type Svc = Either, Response, crate::Error>>; const DEFAULT_BUFFER_SIZE: usize = 1024; @@ -201,7 +201,7 @@ impl Channel { } impl Service> for Channel { - type Response = http::Response; + type Response = http::Response; type Error = super::Error; type Future = ResponseFuture; @@ -217,7 +217,7 @@ impl Service> for Channel { } impl Future for ResponseFuture { - type Output = Result, super::Error>; + type Output = Result, super::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let val = ready!(Pin::new(&mut self.inner).poll(cx)).map_err(super::Error::from_source)?; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index b3428aa2c..1fa059c96 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -21,7 +21,8 @@ use tower::{ }; use tower_service::Service; -pub(crate) use crate::transport::{Request, Response}; +pub(crate) type Response = http::Response; +pub(crate) type Request = http::Request; pub(crate) struct Connection { inner: BoxService, From 233f0c8ff853520d0c95037f506cf49aa83f793d Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:31:02 +0000 Subject: [PATCH 05/25] [tests] Convert tonic::codec::prost::tests to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. This also handles the return types which should now be wrapped in `Frame` when appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/codec/prost.rs | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 217934e9e..3e6789daf 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -156,6 +156,7 @@ mod tests { use crate::{Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http_body::Body; + use http_body_util::BodyExt as _; use std::pin::pin; const LEN: usize = 10000; @@ -238,7 +239,7 @@ mod tests { None, )); - while let Some(r) = body.data().await { + while let Some(r) = body.frame().await { r.unwrap(); } } @@ -260,12 +261,15 @@ mod tests { Some(MAX_MESSAGE_SIZE), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "11" @@ -292,12 +296,15 @@ mod tests { Some(usize::MAX), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "8" @@ -343,7 +350,7 @@ mod tests { mod body { use crate::Status; use bytes::Bytes; - use http_body::Body; + use http_body::{Body, Frame}; use std::{ pin::Pin, task::{Context, Poll}, @@ -374,10 +381,10 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { // every other call to poll_data returns data let should_send = self.count % 2 == 0; let data_len = self.data.len(); @@ -395,18 +402,11 @@ mod tests { }; // make some fake progress self.count += 1; - result + result.map(|opt| opt.map(|res| res.map(|data| Frame::data(data)))) } else { Poll::Ready(None) } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } } } From a1f524f68d10ade7edb813f7fd2566b9687e4f30 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:30:04 +0000 Subject: [PATCH 06/25] Convert tonic::codec::encode to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/codec/encode.rs | 46 ++++++++++++++------------------------- 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 07aed1dda..3f7c628ee 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -5,7 +5,7 @@ use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, H use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project::pin_project; use std::{ pin::Pin, @@ -298,22 +298,21 @@ where } impl EncodeState { - fn trailers(&mut self) -> Result, Status> { + fn trailers(&mut self) -> Option> { match self.role { - Role::Client => Ok(None), + Role::Client => None, Role::Server => { if self.is_end_stream { - return Ok(None); + return None; } + self.is_end_stream = true; let status = if let Some(status) = self.error.take() { - self.is_end_stream = true; status } else { Status::new(Code::Ok, "") }; - - Ok(Some(status.to_header_map()?)) + Some(status.to_header_map()) } } } @@ -330,38 +329,25 @@ where self.state.is_end_stream } - fn size_hint(&self) -> http_body::SizeHint { - let sh = self.inner.size_hint(); - let mut size_hint = http_body::SizeHint::new(); - size_hint.set_lower(sh.0 as u64); - if let Some(upper) = sh.1 { - size_hint.set_upper(upper as u64); - } - size_hint - } - - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let self_proj = self.project(); match ready!(self_proj.inner.poll_next(cx)) { - Some(Ok(d)) => Some(Ok(d)).into(), + Some(Ok(d)) => Some(Ok(Frame::data(d))).into(), Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { - self_proj.state.error = Some(status); - None.into() + self_proj.state.is_end_stream = true; + Some(Ok(Frame::trailers(status.to_header_map()?))).into() } }, - None => None.into(), + None => self_proj + .state + .trailers() + .map(|t| t.map(Frame::trailers)) + .into(), } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Status>> { - Poll::Ready(self.project().state.trailers()) - } } From 79317204a19b21574993ea50ad18a8fc2aa49a3c Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:31:37 +0000 Subject: [PATCH 07/25] [tests] Convert tonic::service::interceptor::tests to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. This also handles the return types which should now be wrapped in `Frame` when appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/service/interceptor.rs | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index cadff466f..ebe78093d 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -232,11 +232,8 @@ where mod tests { #[allow(unused_imports)] use super::*; - use http::header::HeaderMap; - use std::{ - pin::Pin, - task::{Context, Poll}, - }; + use http_body::Frame; + use http_body_util::Empty; use tower::ServiceExt; #[derive(Debug, Default)] @@ -246,19 +243,12 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { Poll::Ready(None) } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } #[tokio::test] @@ -318,17 +308,17 @@ mod tests { #[tokio::test] async fn doesnt_change_http_method() { - let svc = tower::service_fn(|request: http::Request| async move { + let svc = tower::service_fn(|request: http::Request>| async move { assert_eq!(request.method(), http::Method::OPTIONS); - Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + Ok::<_, hyper::Error>(hyper::Response::new(Empty::new())) }); let svc = InterceptedService::new(svc, Ok); let request = http::Request::builder() .method(http::Method::OPTIONS) - .body(hyper::Body::empty()) + .body(Empty::new()) .unwrap(); svc.oneshot(request).await.unwrap(); From 87444d1e5e51a6ea2c40b5e62146c5ed95ca5810 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:39:22 +0000 Subject: [PATCH 08/25] Convert tonic::transport to use http >= 1 body types Here, we must update some body types which are no longer valid. (A) BoxBody no longer has an `empty` method, instead we provide a helper in `tonic::body` for creating an empty boxed body via `http_body_util`. As well, `hyper::Body` is no longer a type, and instead, `hyper::Incoming` is used when directly recieving a Request from hyper, and `BoxBody` is used when the request may have passed through an axum router. In tonic, we prefer `BoxBody` as it allows for services to be used downstream from other components which enforce a specific body type (e.g. Axum), at the cost of making Body streaming opaque. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/transport/mod.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index a0435c797..1f5843754 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -1,7 +1,7 @@ //! Batteries included server and client. //! //! This module provides a set of batteries included, fully featured and -//! fast set of HTTP/2 server and client's. These components each provide a or +//! fast set of HTTP/2 server and client's. These components each provide a //! `rustls` tls backend when the respective feature flag is enabled, and //! provides builders to configure transport behavior. //! @@ -22,6 +22,7 @@ //! # use tonic::transport::{Channel, Certificate, ClientTlsConfig}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; +//! # use tonic::body::empty_body; //! # use tonic::client::GrpcService;; //! # use http::Request; //! # #[cfg(feature = "rustls")] @@ -38,7 +39,7 @@ //! .connect() //! .await?; //! -//! channel.call(Request::new(BoxBody::empty())).await?; +//! channel.call(Request::new(empty_body())).await?; //! # Ok(()) //! # } //! ``` @@ -46,21 +47,23 @@ //! ## Server //! //! ```no_run +//! # use std::convert::Infallible; //! # #[cfg(feature = "rustls")] //! # use tonic::transport::{Server, Identity, ServerTlsConfig}; +//! # use tonic::body::BoxBody; //! # use tower::Service; //! # #[cfg(feature = "rustls")] //! # async fn do_thing() -> Result<(), Box> { //! # #[derive(Clone)] //! # pub struct Svc; -//! # impl Service> for Svc { //! # type Response = hyper::Response; -//! # type Error = tonic::Status; +//! # impl Service> for Svc { +//! # type Error = Infallible; //! # type Future = std::future::Ready>; //! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { //! # Ok(()).into() //! # } -//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { +//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { //! # unimplemented!() //! # } //! # } @@ -108,7 +111,8 @@ pub use self::service::grpc_timeout::TimeoutExpired; #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; -pub use hyper::{Body, Uri}; +pub use hyper::body::Body; +pub use hyper::Uri; pub(crate) use self::service::executor::Executor; @@ -121,5 +125,8 @@ pub use self::server::ServerTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Identity; +use crate::body::BoxBody; type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; +pub(crate) type Response = http::Response; +pub(crate) type Request = http::Request; From 070aba1a57ac5e17faec2fe5322b4d2893731b35 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:40:14 +0000 Subject: [PATCH 09/25] Convert tonic::transport::server::recover_error to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev Co-authored-by: Ludea --- tonic/src/transport/server/recover_error.rs | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs index fdb14a66a..60b0d9a7b 100644 --- a/tonic/src/transport/server/recover_error.rs +++ b/tonic/src/transport/server/recover_error.rs @@ -1,5 +1,6 @@ use crate::Status; use http::Response; +use http_body::Frame; use pin_project::pin_project; use std::{ future::Future, @@ -98,26 +99,16 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.as_pin_mut() { - Some(b) => b.poll_data(cx), + Some(b) => b.poll_frame(cx), None => Poll::Ready(None), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.as_pin_mut() { - Some(b) => b.poll_trailers(cx), - None => Poll::Ready(Ok(None)), - } - } - fn is_end_stream(&self) -> bool { match &self.inner { Some(b) => b.is_end_stream(), From c61bf04c487ad53a9f71289571697f28be429e25 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:25:12 +0000 Subject: [PATCH 10/25] Convert h2c examples to use http >= 1 body types In h2c, when a service is receiving from hyper, it has to accept a `hyper::body::Incoming` in hyper >= 1. Additionally, response bodies must be built from `http_body_util` combinators and become BoxBody objects. --- examples/src/h2c/client.rs | 11 +++++++---- examples/src/h2c/server.rs | 22 ++++++++++------------ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 31076b1ac..624ea175f 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -34,15 +34,18 @@ mod h2c { }; use hyper::{client::HttpConnector, Client}; - use tonic::body::BoxBody; + use hyper::body::Incoming; + use hyper_util::{ + rt::TokioExecutor, + use tonic::body::{empty_body, BoxBody}; use tower::Service; pub struct H2cChannel { - pub client: Client, + pub client: Client, } impl Service> for H2cChannel { - type Response = http::Response; + type Response = http::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -60,7 +63,7 @@ mod h2c { let h2c_req = hyper::Request::builder() .uri(origin) .header(http::header::UPGRADE, "h2c") - .body(hyper::Body::empty()) + .body(empty_body()) .unwrap(); let res = client.request(h2c_req).await.unwrap(); diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 92d08a417..21dcc1f35 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -49,7 +49,9 @@ mod h2c { use std::pin::Pin; use http::{Request, Response}; - use hyper::Body; + use hyper::body::Incoming; + use hyper_util::{rt::TokioExecutor, service::TowerToHyperService}; + use tonic::{body::empty_body, transport::AxumBoxBody}; use tower::Service; #[derive(Clone)] @@ -59,17 +61,14 @@ mod h2c { type BoxError = Box; - impl Service> for H2c + impl Service> for H2c where - S: Service, Response = Response> - + Clone - + Send - + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Sync + Send + 'static, S::Response: Send + 'static, { - type Response = hyper::Response; + type Response = hyper::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -81,20 +80,19 @@ mod h2c { std::task::Poll::Ready(Ok(())) } - fn call(&mut self, mut req: hyper::Request) -> Self::Future { + fn call(&mut self, mut req: hyper::Request) -> Self::Future { let svc = self.s.clone(); Box::pin(async move { tokio::spawn(async move { let upgraded_io = hyper::upgrade::on(&mut req).await.unwrap(); - hyper::server::conn::Http::new() - .http2_only(true) - .serve_connection(upgraded_io, svc) + hyper::server::conn::http2::Builder::new(TokioExecutor::new()) + .serve_connection(upgraded_io, TowerToHyperService::new(svc)) .await .unwrap(); }); - let mut res = hyper::Response::new(hyper::Body::empty()); + let mut res = hyper::Response::new(empty_body()); *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS; res.headers_mut().insert( hyper::header::UPGRADE, From b6c900e28aac58d48e6409f3411c6ae4e1e24b4f Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:54:14 +0000 Subject: [PATCH 11/25] [tests] Convert MergeTrailers body wrapper in interop server The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- interop/src/server.rs | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/interop/src/server.rs b/interop/src/server.rs index aef7b0d45..38b1be65e 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -1,10 +1,10 @@ use crate::pb::{self, *}; use async_stream::try_stream; -use http::header::{HeaderMap, HeaderName, HeaderValue}; +use http::header::{HeaderName, HeaderValue}; use http_body::Body; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio_stream::StreamExt; use tonic::{body::BoxBody, server::NamedService, Code, Request, Response, Status}; @@ -235,25 +235,19 @@ impl Body for MergeTrailers { type Data = B::Data; type Error = B::Error; - fn poll_data( - mut self: Pin<&mut Self>, + fn poll_frame( + self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(&mut self.inner).poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { - h.map(|mut headers| { - if let Some((key, value)) = &self.trailer { - headers.insert(key.clone(), value.clone()); + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + let mut frame = ready!(Pin::new(&mut this.inner).poll_frame(cx)?); + if let Some(frame) = frame.as_mut() { + if let Some(trailers) = frame.trailers_mut() { + if let Some((key, value)) = &this.trailer { + trailers.insert(key.clone(), value.clone()); } - - headers - }) - }) + } + } + Poll::Ready(frame.map(Ok)) } } From f388657ca148f87c4405a75b70429abd40f5f99c Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:42:34 +0000 Subject: [PATCH 12/25] [tests] Convert compression tests to use hyper 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tests/compression/Cargo.toml | 2 +- tests/compression/src/util.rs | 81 ++++++++++++++++++++++++----------- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 4ba549cdc..cf4da321b 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -20,7 +20,7 @@ tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tokio-stream = "0.1" tonic = {path = "../../tonic", features = ["gzip", "zstd"]} tower = {version = "0.4", features = []} -tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]} +tower-http = {version = "0.5", features = ["map-response-body", "map-request-body"]} [build-dependencies] tonic-build = {path = "../../tonic-build" } diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 28fa5d96a..99afded3f 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -1,6 +1,8 @@ use super::*; -use bytes::Bytes; -use http_body::Body; +use bytes::{Buf, Bytes}; +use http_body::{Body, Frame}; +use http_body_util::BodyExt as _; +use hyper_util::rt::TokioIo; use pin_project::pin_project; use std::{ pin::Pin, @@ -11,6 +13,7 @@ use std::{ task::{ready, Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::body::BoxBody; use tonic::codec::CompressionEncoding; use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; @@ -46,29 +49,22 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); let counter: Arc = this.counter.clone(); - match ready!(this.inner.poll_data(cx)) { + match ready!(this.inner.poll_frame(cx)) { Some(Ok(chunk)) => { - println!("response body chunk size = {}", chunk.len()); - counter.fetch_add(chunk.len(), SeqCst); + println!("response body chunk size = {}", frame_data_length(&chunk)); + counter.fetch_add(frame_data_length(&chunk), SeqCst); Poll::Ready(Some(Ok(chunk))) } x => Poll::Ready(x), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -78,28 +74,61 @@ where } } +fn frame_data_length(frame: &http_body::Frame) -> usize { + if let Some(data) = frame.data_ref() { + data.len() + } else { + 0 + } +} + +#[pin_project] +struct ChannelBody { + #[pin] + rx: tokio::sync::mpsc::Receiver>, +} + +impl ChannelBody { + pub fn new() -> (tokio::sync::mpsc::Sender>, Self) { + let (tx, rx) = tokio::sync::mpsc::channel(32); + (tx, Self { rx }) + } +} + +impl Body for ChannelBody +where + T: Buf, +{ + type Data = T; + type Error = tonic::Status; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let frame = ready!(self.project().rx.poll_recv(cx)); + Poll::Ready(frame.map(Ok)) + } +} + #[allow(dead_code)] pub fn measure_request_body_size_layer( bytes_sent_counter: Arc, -) -> MapRequestBodyLayer hyper::Body + Clone> { - MapRequestBodyLayer::new(move |mut body: hyper::Body| { - let (mut tx, new_body) = hyper::Body::channel(); +) -> MapRequestBodyLayer BoxBody + Clone> { + MapRequestBodyLayer::new(move |mut body: BoxBody| { + let (tx, new_body) = ChannelBody::new(); let bytes_sent_counter = bytes_sent_counter.clone(); tokio::spawn(async move { - while let Some(chunk) = body.data().await { + while let Some(chunk) = body.frame().await { let chunk = chunk.unwrap(); - println!("request body chunk size = {}", chunk.len()); - bytes_sent_counter.fetch_add(chunk.len(), SeqCst); - tx.send_data(chunk).await.unwrap(); - } - - if let Some(trailers) = body.trailers().await.unwrap() { - tx.send_trailers(trailers).await.unwrap(); + println!("request body chunk size = {}", frame_data_length(&chunk)); + bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst); + tx.send(chunk).await.unwrap(); } }); - new_body + new_body.boxed_unsync() }) } From 60594e7a71f9f95401c668ef5ad890c606bb212c Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:55:15 +0000 Subject: [PATCH 13/25] [tests] Convert complex_tower_middleware Body for hyper 1 The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tests/integration_tests/Cargo.toml | 2 +- .../tests/complex_tower_middleware.rs | 11 ++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 6a7ec8052..cfeebf725 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -23,7 +23,7 @@ hyper = "1" hyper-util = "0.1" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} -tower-http = { version = "0.4", features = ["set-header", "trace"] } +tower-http = { version = "0.5", features = ["set-header", "trace"] } tower-service = "0.3" tracing = "0.1" diff --git a/tests/integration_tests/tests/complex_tower_middleware.rs b/tests/integration_tests/tests/complex_tower_middleware.rs index 5d7690be3..b1b669426 100644 --- a/tests/integration_tests/tests/complex_tower_middleware.rs +++ b/tests/integration_tests/tests/complex_tower_middleware.rs @@ -97,17 +97,10 @@ where type Data = B::Data; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - unimplemented!() - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { + ) -> Poll, Self::Error>>> { unimplemented!() } } From 86d55afacb423d721aa7cd2378c83c731fa1ba11 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:06:13 +0000 Subject: [PATCH 14/25] [tests] Convert integration_tests::origin to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tests/integration_tests/tests/origin.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index c8140c79f..e41287245 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -7,7 +7,6 @@ use std::time::Duration; use tokio::sync::oneshot; use tonic::codegen::http::Request; use tonic::{ - body::BoxBody, transport::{Endpoint, Server}, Response, Status, }; @@ -77,9 +76,9 @@ struct OriginService { inner: S, } -impl Service> for OriginService +impl Service> for OriginService where - T: Service>, + T: Service>, T::Future: Send + 'static, T::Error: Into>, { @@ -91,7 +90,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { assert_eq!(req.uri().host(), Some("docs.rs")); let fut = self.inner.call(req); From 917cfe7349eb6c54cd51292a42795ae1701f9baa Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 02:56:09 +0000 Subject: [PATCH 15/25] Convert tonic-web to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. --- tonic-web/Cargo.toml | 2 +- tonic-web/src/call.rs | 134 ++++++++++-------- tonic-web/tests/integration/tests/grpc_web.rs | 5 +- 3 files changed, 78 insertions(+), 63 deletions(-) diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index 5813fd6a3..157605a95 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -25,7 +25,7 @@ pin-project = "1" tonic = {version = "0.11", path = "../tonic", default-features = false} tower-service = "0.3" tower-layer = "0.3" -tower-http = { version = "0.4", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors"] } tracing = "0.1" [dev-dependencies] diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index f52087e9e..178e620ae 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -5,7 +5,7 @@ use std::task::{ready, Context, Poll}; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{header, HeaderMap, HeaderName, HeaderValue}; -use http_body::{Body, SizeHint}; +use http_body::{Body, Frame, SizeHint}; use pin_project::pin_project; use tokio_stream::Stream; use tonic::Status; @@ -63,9 +63,9 @@ pub struct GrpcWebCall { #[pin] inner: B, buf: BytesMut, + decoded: BytesMut, direction: Direction, encoding: Encoding, - poll_trailers: bool, client: bool, trailers: Option, } @@ -75,9 +75,9 @@ impl Default for GrpcWebCall { Self { inner: Default::default(), buf: Default::default(), + decoded: Default::default(), direction: Direction::Empty, encoding: Encoding::None, - poll_trailers: Default::default(), client: Default::default(), trailers: Default::default(), } @@ -108,9 +108,12 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(match direction { + Direction::Decode => BUFFER_SIZE, + _ => 0, + }), direction, encoding, - poll_trailers: true, client: true, trailers: None, } @@ -123,9 +126,9 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(0), direction, encoding, - poll_trailers: true, client: false, trailers: None, } @@ -160,24 +163,37 @@ where B: Body, B::Error: Error, { + // Poll body for data, decoding (e.g. via Base64 if necessary) and returning frames + // to the caller. If the caller is a client, it should look for trailers before + // returning these frames. fn poll_decode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { match self.encoding { Encoding::Base64 => loop { if let Some(bytes) = self.as_mut().decode_chunk()? { - return Poll::Ready(Some(Ok(bytes))); + return Poll::Ready(Some(Ok(Frame::data(bytes)))); } let mut this = self.as_mut().project(); - match ready!(this.inner.as_mut().poll_data(cx)) { - Some(Ok(data)) => this.buf.put(data), + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => this.buf.put(frame.into_data().unwrap()), + Some(Ok(frame)) if frame.is_trailers() => { + return Poll::Ready(Some(Err(internal_error( + "malformed base64 request has unencoded trailers", + )))) + } + Some(Ok(_)) => { + return Poll::Ready(Some(Err(internal_error("unexpected frame type")))) + } Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))), None => { return if this.buf.has_remaining() { Poll::Ready(Some(Err(internal_error("malformed base64 request")))) + } else if let Some(trailers) = this.trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } @@ -185,7 +201,7 @@ where } }, - Encoding::None => match ready!(self.project().inner.poll_data(cx)) { + Encoding::None => match ready!(self.project().inner.poll_frame(cx)) { Some(res) => Poll::Ready(Some(res.map_err(internal_error))), None => Poll::Ready(None), }, @@ -195,37 +211,33 @@ where fn poll_encode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { let mut this = self.as_mut().project(); - if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) { - if *this.encoding == Encoding::Base64 { - res = res.map(|b| crate::util::base64::STANDARD.encode(b).into()) - } - - return Poll::Ready(Some(res.map_err(internal_error))); - } - - // this flag is needed because the inner stream never - // returns Poll::Ready(None) when polled for trailers - if *this.poll_trailers { - return match ready!(this.inner.poll_trailers(cx)) { - Ok(Some(map)) => { - let mut frame = make_trailers_frame(map); + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => { + let mut res = frame.into_data().unwrap(); - if *this.encoding == Encoding::Base64 { - frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); - } + if *this.encoding == Encoding::Base64 { + let mut buf = Vec::with_capacity(res.len()); + buf.extend_from_slice(&res); + res = crate::util::base64::STANDARD.encode(buf).into(); + } - *this.poll_trailers = false; - Poll::Ready(Some(Ok(frame.into()))) + Poll::Ready(Some(Ok(Frame::data(res)))) + } + Some(Ok(frame)) if frame.is_trailers() => { + let trailers = frame.into_trailers().expect("must be trailers"); + let mut frame = make_trailers_frame(trailers); + if *this.encoding == Encoding::Base64 { + frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); } - Ok(None) => Poll::Ready(None), - Err(e) => Poll::Ready(Some(Err(internal_error(e)))), - }; + Poll::Ready(Some(Ok(Frame::data(frame.into())))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexepected frame type")))), + Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))), + None => Poll::Ready(None), } - - Poll::Ready(None) } } @@ -237,28 +249,34 @@ where type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { if self.client && self.direction == Direction::Decode { let mut me = self.as_mut(); loop { - let incoming_buf = match ready!(me.as_mut().poll_decode(cx)) { - Some(Ok(incoming_buf)) => incoming_buf, - None => { - // TODO: Consider eofing here? - // Even if the buffer has more data, this will hit the eof branch - // of decode in tonic - return Poll::Ready(None); + match ready!(me.as_mut().poll_decode(cx)) { + Some(Ok(incoming_buf)) if incoming_buf.is_data() => { + me.as_mut() + .project() + .decoded + .put(incoming_buf.into_data().unwrap()); + } + Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => { + let trailers = incoming_buf.into_trailers().unwrap(); + me.as_mut().project().trailers.replace(trailers); + continue; } + Some(Ok(_)) => unreachable!("unexpected frame type"), + None => {} // No more data to decode, time to look for trailers Some(Err(e)) => return Poll::Ready(Some(Err(e))), }; - let buf = &mut me.as_mut().project().buf; - - buf.put(incoming_buf); + // Hold the incoming, decoded data until we have a full message + // or trailers to return. + let buf = me.as_mut().project().decoded; return match find_trailers(&buf[..])? { FindTrailers::Trailer(len) => { @@ -266,20 +284,24 @@ where let msg_buf = buf.copy_to_bytes(len); match decode_trailers_frame(buf.split().freeze()) { Ok(Some(trailers)) => { - self.project().trailers.replace(trailers); + me.as_mut().project().trailers.replace(trailers); } Err(e) => return Poll::Ready(Some(Err(e))), _ => {} } if msg_buf.has_remaining() { - Poll::Ready(Some(Ok(msg_buf))) + Poll::Ready(Some(Ok(Frame::data(msg_buf)))) + } else if let Some(trailers) = me.as_mut().project().trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } } FindTrailers::IncompleteBuf => continue, - FindTrailers::Done(len) => Poll::Ready(Some(Ok(buf.split_to(len).freeze()))), + FindTrailers::Done(len) => { + Poll::Ready(Some(Ok(Frame::data(buf.split_to(len).freeze())))) + } }; } } @@ -291,14 +313,6 @@ where } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>, Self::Error>> { - let trailers = self.project().trailers.take(); - Poll::Ready(Ok(trailers)) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -313,10 +327,10 @@ where B: Body, B::Error: Error, { - type Item = Result; + type Item = Result, Status>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Body::poll_data(self, cx) + self.poll_frame(cx) } } diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 037ff8dad..96720b19e 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http_body_util::{BodyExt as _, Full}; +use hyper::body::Incoming; use hyper::http::{header, StatusCode}; use hyper::{Method, Request, Uri}; use prost::Message; @@ -131,8 +132,8 @@ fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request< .unwrap() } -async fn decode_body(body: Body, content_type: &str) -> (Output, Bytes) { - let mut body = hyper::body::to_bytes(body).await.unwrap(); +async fn decode_body(body: Incoming, content_type: &str) -> (Output, Bytes) { + let mut body = body.collect().await.unwrap().to_bytes(); if content_type == "application/grpc-web-text+proto" { body = integration::util::base64::STANDARD From 0413bb1c56326bed3274d2aec07d17d4dfef3d46 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:26:02 +0000 Subject: [PATCH 16/25] Adapt for hyper-specific IO traits hyper >= 1 provides its own I/O traits (Read & Write) instead of relying on the equivalent traits from `tokio`. Then, `hyper-util` provides adaptor structs to wrap `tokio` I/O objects and implement the hyper equivalents. Therefore, we update the appropriate bounds to use the hyper traits, and update the I/O objects so that they are wrapped in the tokio to hyper adaptor. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/Cargo.toml | 12 +++++----- examples/src/grpc-web/client.rs | 1 + examples/src/h2c/client.rs | 2 ++ examples/src/h2c/server.rs | 1 + examples/src/mock/mock.rs | 3 ++- examples/src/uds/client.rs | 6 +++-- tests/compression/src/util.rs | 2 +- tests/integration_tests/tests/connect_info.rs | 8 ++++++- .../tests/max_message_size.rs | 5 ++-- tests/integration_tests/tests/status.rs | 3 ++- tonic-web/tests/integration/tests/grpc_web.rs | 3 +++ tonic/Cargo.toml | 3 +-- tonic/src/transport/channel/endpoint.rs | 5 ++-- tonic/src/transport/channel/mod.rs | 10 ++++---- tonic/src/transport/server/mod.rs | 1 + tonic/src/transport/service/connection.rs | 8 +++---- tonic/src/transport/service/connector.rs | 24 +++++++++++-------- tonic/src/transport/service/io.rs | 13 +++++----- 18 files changed, 66 insertions(+), 44 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a672287d8..e04868826 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -271,21 +271,21 @@ routeguide = ["dep:async-stream", "tokio-stream", "dep:rand", "dep:serde", "dep: reflection = ["dep:tonic-reflection"] autoreload = ["tokio-stream/net", "dep:listenfd"] health = ["dep:tonic-health"] -grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber", "dep:tower"] +grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:hyper-util", "dep:tracing-subscriber", "dep:tower"] tracing = ["dep:tracing", "dep:tracing-subscriber"] -uds = ["tokio-stream/net", "dep:tower", "dep:hyper"] +uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["tokio-stream", "dep:h2"] -mock = ["tokio-stream", "dep:tower"] -tower = ["dep:hyper", "dep:tower", "dep:http"] +mock = ["tokio-stream", "dep:tower", "dep:hyper-util"] +tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] -tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] +tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] dynamic-load-balance = ["dep:tower"] timeout = ["tokio/time", "dep:tower"] tls-client-auth = ["tonic/tls"] types = ["dep:tonic-types"] -h2c = ["dep:hyper", "dep:tower", "dep:http"] +h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] cancellation = ["dep:tokio-util"] full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "tower", "json-codec", "compression", "tls", "tls-rustls", "dynamic-load-balance", "timeout", "tls-client-auth", "types", "cancellation", "h2c"] diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index a16ac674a..fd20a788b 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -1,4 +1,5 @@ use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioExecutor; use tonic_web::GrpcWebClientLayer; pub mod hello_world { diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 624ea175f..2f9f90a79 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -2,6 +2,7 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use http::Uri; use hyper::Client; +use hyper_util::rt::TokioExecutor; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -12,6 +13,7 @@ async fn main() -> Result<(), Box> { let origin = Uri::from_static("http://[::1]:50051"); let h2c_client = h2c::H2cChannel { client: Client::new(), + client: Client::builder(TokioExecutor::new()).build_http(), }; let mut client = GreeterClient::with_origin(h2c_client, origin); diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 21dcc1f35..b1d4c0a8d 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -1,3 +1,4 @@ +use hyper_util::rt::{TokioExecutor, TokioIo}; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; diff --git a/examples/src/mock/mock.rs b/examples/src/mock/mock.rs index 0d3754921..6c26a6735 100644 --- a/examples/src/mock/mock.rs +++ b/examples/src/mock/mock.rs @@ -1,3 +1,4 @@ +use hyper_util::rt::TokioIo; use tonic::{ transport::{Endpoint, Server, Uri}, Request, Response, Status, @@ -36,7 +37,7 @@ async fn main() -> Result<(), Box> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/examples/src/uds/client.rs b/examples/src/uds/client.rs index e78531ac4..9a09e6981 100644 --- a/examples/src/uds/client.rs +++ b/examples/src/uds/client.rs @@ -5,6 +5,7 @@ pub mod hello_world { } use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixStream; use tonic::transport::{Endpoint, Uri}; @@ -16,12 +17,13 @@ async fn main() -> Result<(), Box> { // We will ignore this uri because uds do not use it // if your connector does use the uri it will be provided // as the request to the `MakeConnection`. + let channel = Endpoint::try_from("http://[::]:50051")? - .connect_with_connector(service_fn(|_: Uri| { + .connect_with_connector(service_fn(|_: Uri| async { let path = "/tmp/tonic/helloworld"; // Connect to a Uds socket - UnixStream::connect(path) + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path).await?)) })) .await?; diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 99afded3f..d7e250ce4 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -139,7 +139,7 @@ pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector(service_fn(move |_: Uri| { - let client = client.take().unwrap(); + let client = TokioIo::new(client.take().unwrap()); async move { Ok::<_, std::io::Error>(client) } })) .await diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs index 94fac8221..e87bb858f 100644 --- a/tests/integration_tests/tests/connect_info.rs +++ b/tests/integration_tests/tests/connect_info.rs @@ -51,6 +51,9 @@ async fn getting_connect_info() { #[cfg(unix)] pub mod unix { + use std::io; + + use hyper_util::rt::TokioIo; use tokio::{ net::{UnixListener, UnixStream}, sync::oneshot, @@ -106,7 +109,10 @@ pub mod unix { let path = unix_socket_path.clone(); let channel = Endpoint::try_from("http://[::]:50051") .unwrap() - .connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone()))) + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + async move { Ok::<_, io::Error>(TokioIo::new(UnixStream::connect(path).await?)) } + })) .await .unwrap(); diff --git a/tests/integration_tests/tests/max_message_size.rs b/tests/integration_tests/tests/max_message_size.rs index 9ae524dbc..f03699cdf 100644 --- a/tests/integration_tests/tests/max_message_size.rs +++ b/tests/integration_tests/tests/max_message_size.rs @@ -1,5 +1,6 @@ use std::pin::Pin; +use hyper_util::rt::TokioIo; use integration_tests::{ pb::{test1_client, test1_server, Input1, Output1}, trace_init, @@ -163,7 +164,7 @@ async fn response_stream_limit() { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, @@ -332,7 +333,7 @@ async fn max_message_run(case: &TestCase) -> Result<(), Status> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 3fdabcd36..df6bc4b3b 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -1,5 +1,6 @@ use bytes::Bytes; use http::Uri; +use hyper_util::rt::TokioIo; use integration_tests::mock::MockStream; use integration_tests::pb::{ test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, @@ -183,7 +184,7 @@ async fn status_from_server_stream_with_source() { let channel = Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move { - Err::(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) + Err::, _>(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) })); let mut client = test_stream_client::TestStreamClient::new(channel); diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 96720b19e..b46d98d45 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -6,6 +6,7 @@ use http_body_util::{BodyExt as _, Full}; use hyper::body::Incoming; use hyper::http::{header, StatusCode}; use hyper::{Method, Request, Uri}; +use hyper_util::rt::TokioExecutor; use prost::Message; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; @@ -20,6 +21,7 @@ use tonic_web::GrpcWebLayer; async fn binary_request() { let server_url = spawn().await; let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web", "grpc-web"); let res = client.request(req).await.unwrap(); @@ -43,6 +45,7 @@ async fn binary_request() { async fn text_request() { let server_url = spawn().await; let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); let res = client.request(req).await.unwrap(); diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 1934c416e..51b38a6e6 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,10 +37,9 @@ transport = [ "dep:axum", "channel", "dep:h2", - "dep:hyper", "dep:tokio", "tokio?/net", "tokio?/time", + "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", "dep:tower", - "dep:hyper-timeout", ] channel = [] diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 995e2a15b..584c56f8c 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -7,6 +7,7 @@ use crate::transport::service::TlsConnector; use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; +use hyper::rt; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; use tower::make::MakeConnection; @@ -369,7 +370,7 @@ impl Endpoint { pub async fn connect_with_connector(&self, connector: C) -> Result where C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -394,7 +395,7 @@ impl Endpoint { pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 6a857dff1..3e5869bcb 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -25,11 +25,9 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{channel, Sender}, -}; +use tokio::sync::mpsc::{channel, Sender}; +use hyper::rt; use tower::balance::p2c::Balance; use tower::{ buffer::{self, Buffer}, @@ -152,7 +150,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -169,7 +167,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index ad930c617..0c64ba6d0 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -13,6 +13,7 @@ pub use super::service::Routes; pub use super::service::RoutesBuilder; pub use conn::{Connected, TcpConnectInfo}; +use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 1fa059c96..8e1f52c5f 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -7,11 +7,11 @@ use http::Uri; use hyper::client::conn::Builder; use hyper::client::connect::Connection as HyperConnection; use hyper::client::service::Connect as HyperConnect; +use hyper::rt; use std::{ fmt, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tower::load::Load; use tower::{ layer::Layer, @@ -34,7 +34,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { let mut settings = Builder::new() .http2_initial_stream_window_size(endpoint.init_stream_window_size) @@ -83,7 +83,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, false).ready_oneshot().await } @@ -93,7 +93,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, true) } diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 12336813a..8219fe8d9 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -6,7 +6,11 @@ use http::Uri; #[cfg(feature = "tls")] use std::fmt; use std::task::{Context, Poll}; -use tower::make::MakeConnection; + +use hyper::rt; + +#[cfg(feature = "tls")] +use hyper_util::rt::TokioIo; use tower_service::Service; pub(crate) struct Connector { @@ -51,8 +55,8 @@ impl Connector { impl Service for Connector where - C: MakeConnection, - C::Connection: Unpin + Send + 'static, + C: Service, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -61,7 +65,7 @@ where type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into) + self.inner.poll_ready(cx).map_err(Into::into) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -73,7 +77,7 @@ where #[cfg(feature = "tls")] let is_https = uri.scheme_str() == Some("https"); - let connect = self.inner.make_connection(uri); + let connect = self.inner.call(uri); Box::pin(async move { let io = connect.await?; @@ -81,12 +85,12 @@ where #[cfg(feature = "tls")] { if let Some(tls) = tls { - if is_https { - let conn = tls.connect(io).await?; - return Ok(BoxedIo::new(conn)); + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) } else { - return Ok(BoxedIo::new(io)); - } + Ok(BoxedIo::new(io)) + }; } else if is_https { return Err(HttpsUriWithoutTlsSupport(()).into()); } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2230b9b2e..cb2296cac 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,5 +1,6 @@ use crate::transport::server::Connected; -use hyper::client::connect::{Connected as HyperConnected, Connection}; +use hyper::rt; +use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -9,11 +10,11 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static + rt::Read + rt::Write + Send + 'static { } -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} +impl Io for T where T: rt::Read + rt::Write + Send + 'static {} pub(crate) struct BoxedIo(Pin>); @@ -40,17 +41,17 @@ impl Connected for BoxedIo { #[derive(Copy, Clone)] pub(crate) struct NoneConnectInfo; -impl AsyncRead for BoxedIo { +impl rt::Read for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } -impl AsyncWrite for BoxedIo { +impl rt::Write for BoxedIo { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, From c5f6ded857c19a61553e6e91d63221f0a6146574 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:00:15 +0000 Subject: [PATCH 17/25] Upgrade axum to 0.7 Axum must be >= 0.7 to support hyper >= 1 Doing this also involves changing the Body type used. Since hyper >= 1 does not provide a generic body type, Axum and tonic both use `BoxBody` to provide a pointer to a Body. This changes the trait bounds required for methods which accept additional Serivces to be run alongside the primary GRPC service, since those will be routed with Axum, and therefore must accept a BoxBody. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/Cargo.toml | 2 +- tonic/src/transport/mod.rs | 7 +- tonic/src/transport/server/incoming.rs | 4 +- tonic/src/transport/server/mod.rs | 340 ++++++++++++++++++++----- tonic/src/transport/service/router.rs | 53 +++- 5 files changed, 325 insertions(+), 81 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 51b38a6e6..da4482291 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -76,7 +76,7 @@ socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } tokio = {version = "1", default-features = false, optional = true} tokio-stream = { version = "0.1", features = ["net"] } tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} -axum = {version = "0.6.9", default-features = false, optional = true} +axum = {version = "0.7", default-features = false, optional = true} # rustls rustls-pemfile = { version = "2.0", optional = true } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 1f5843754..0301db3fd 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -56,8 +56,8 @@ //! # async fn do_thing() -> Result<(), Box> { //! # #[derive(Clone)] //! # pub struct Svc; -//! # type Response = hyper::Response; //! # impl Service> for Svc { +//! # type Response = hyper::Response; //! # type Error = Infallible; //! # type Future = std::future::Ready>; //! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { @@ -110,7 +110,7 @@ pub use self::service::grpc_timeout::TimeoutExpired; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; -pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; +pub use axum::{body::Body as AxumBoxBody, Router as AxumRouter}; pub use hyper::body::Body; pub use hyper::Uri; @@ -125,8 +125,5 @@ pub use self::server::ServerTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Identity; -use crate::body::BoxBody; type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; -pub(crate) type Response = http::Response; -pub(crate) type Request = http::Request; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index bc1bb7650..ede62a32d 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -139,13 +139,13 @@ impl TcpIncoming { /// ```no_run /// # use tower_service::Service; /// # use http::{request::Request, response::Response}; - /// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}}; + /// # use tonic::{body::BoxBody, server::NamedService, transport::{Server, server::TcpIncoming}}; /// # use core::convert::Infallible; /// # use std::error::Error; /// # fn main() { } // Cannot have type parameters, hence instead define: /// # fn run(some_service: S) -> Result<(), Box> /// # where - /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, + /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, /// # S::Future: Send + 'static, /// # { /// // Find a free port diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 0c64ba6d0..fb63058ad 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,6 +9,13 @@ mod tls; #[cfg(unix)] mod unix; +use tokio_stream::StreamExt as _; +use tower::util::BoxCloneService; +use tower::util::Oneshot; +use tower::ServiceExt; +use tracing::debug; +use tracing::trace; + pub use super::service::Routes; pub use super::service::RoutesBuilder; @@ -42,15 +49,16 @@ use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; use http_body_util::BodyExt; -use hyper::server::accept; +use hyper::body::Incoming; use pin_project::pin_project; +use std::future::poll_fn; use std::{ convert::Infallible, fmt, future::{self, Future}, marker::PhantomData, net::SocketAddr, - pin::Pin, + pin::{pin, Pin}, sync::Arc, task::{ready, Context, Poll}, time::Duration, @@ -65,18 +73,17 @@ use tower::{ Service, ServiceBuilder, }; -type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; type BoxHttpBody = crate::body::BoxBody; -type Body = hyper::body::Incoming; // Temporary type alias to ease transition type BoxError = crate::Error; -type BoxService = tower::util::BoxCloneService, Response, crate::Error>; +type BoxService = tower::util::BoxCloneService, Response, crate::Error>; +type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; /// A default batteries included `transport` server. /// -/// This is a wrapper around [`hyper::Server`] and provides an easy builder -/// pattern style builder [`Server`]. This builder exposes easy configuration parameters +/// This provides an easy builder pattern style builder [`Server`] on top of +/// `hyper` connections. This builder exposes easy configuration parameters /// for providing a fully featured http2 based gRPC server. This should provide /// a very good out of the box http2 server for use with tonic but is also a /// reference implementation that should be a good starting point for anyone @@ -126,7 +133,7 @@ impl Default for Server { } } -/// A stack based `Service` router. +/// A stack based [`Service`] router. #[derive(Debug)] pub struct Router { server: Server, @@ -363,7 +370,7 @@ impl Server { /// route around different services. pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -384,7 +391,7 @@ impl Server { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -498,9 +505,11 @@ impl Server { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send + 'static, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IO::ConnectInfo: Clone + Send + Sync + 'static, @@ -527,10 +536,8 @@ impl Server { let svc = self.service_builder.service(svc); - let tcp = incoming::tcp_incoming(incoming, self); - let incoming = accept::from_stream::<_, _, crate::Error>(tcp); - - let svc = MakeSvc { + let incoming = incoming::tcp_incoming(incoming, self); + let mut svc = MakeSvc { inner: svc, concurrency_limit, timeout, @@ -538,31 +545,204 @@ impl Server { _io: PhantomData, }; - let server = hyper::Server::builder(incoming) - .http2_only(http2_only) - .http2_initial_connection_window_size(init_connection_window_size) - .http2_initial_stream_window_size(init_stream_window_size) - .http2_max_concurrent_streams(max_concurrent_streams) - .http2_keep_alive_interval(http2_keepalive_interval) - .http2_keep_alive_timeout(http2_keepalive_timeout) - .http2_adaptive_window(http2_adaptive_window.unwrap_or_default()) - .http2_max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) - .http2_max_frame_size(max_frame_size); - - if let Some(signal) = signal { - server - .serve(svc) - .with_graceful_shutdown(signal) - .await - .map_err(super::Error::from_source)? - } else { - server.serve(svc).await.map_err(super::Error::from_source)?; + let server = { + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + + if http2_only { + builder = builder.http2_only(); + } + + builder + .http2() + .initial_connection_window_size(init_connection_window_size) + .initial_stream_window_size(init_stream_window_size) + .max_concurrent_streams(max_concurrent_streams) + .keep_alive_interval(http2_keepalive_interval) + .keep_alive_timeout(http2_keepalive_timeout) + .adaptive_window(http2_adaptive_window.unwrap_or_default()) + .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) + .max_frame_size(max_frame_size); + + builder + }; + + let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + let graceful = signal.is_some(); + let mut sig = pin!(Fuse { inner: signal }); + let mut incoming = pin!(incoming); + + loop { + tokio::select! { + _ = &mut sig => { + trace!("signal received, shutting down"); + break; + }, + io = incoming.next() => { + let io = match io { + Some(Ok(io)) => io, + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + continue; + }, + None => { + break + }, + }; + + trace!("connection accepted"); + + poll_fn(|cx| svc.poll_ready(cx)) + .await + .map_err(super::Error::from_source)?; + + let req_svc = svc + .call(&io) + .await + .map_err(super::Error::from_source)?; + let hyper_svc = TowerToHyperService::new(req_svc); + + serve_connection(io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone())); + } + } + } + + if graceful { + let _ = signal_tx.send(()); + drop(signal_rx); + trace!( + "waiting for {} connections to close", + signal_tx.receiver_count() + ); + + // Wait for all connections to close + signal_tx.closed().await; } Ok(()) } } +// This is moved to its own function as a way to get around +// https://github.com/rust-lang/rust/issues/102211 +fn serve_connection( + io: ServerIo, + hyper_svc: TowerToHyperService, + builder: ConnectionBuilder, + mut watcher: Option>, +) where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, +{ + tokio::spawn(async move { + { + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + let mut conn = pin!(builder.serve_connection(TokioIo::new(io), hyper_svc)); + + loop { + tokio::select! { + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + } + } + } + + drop(watcher); + trace!("connection closed"); + }); +} + +type ConnectionBuilder = hyper_util::server::conn::auto::Builder; + +/// An adaptor which converts a [`tower::Service`] to a [`hyper::service::Service`]. +/// +/// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, +/// and does not support the `poll_ready` method that is used by tower services. +#[derive(Debug, Copy, Clone)] +pub struct TowerToHyperService { + service: S, +} + +impl TowerToHyperService { + /// Create a new `TowerToHyperService` from a tower service. + pub fn new(service: S) -> Self { + Self { service } + } + + /// Extract the inner tower service. + pub fn into_inner(self) -> S { + self.service + } + + /// Get a reference to the inner tower service. + pub fn as_inner(&self) -> &S { + &self.service + } + + /// Get a mutable reference to the inner tower service. + pub fn as_inner_mut(&mut self) -> &mut S { + &mut self.service + } +} + +impl hyper::service::Service> for TowerToHyperService +where + S: tower_service::Service> + Clone, + S::Error: Into + 'static, +{ + type Response = S::Response; + type Error = super::Error; + type Future = TowerToHyperServiceFuture>; + + fn call(&self, req: Request) -> Self::Future { + let req = req.map(crate::body::boxed); + TowerToHyperServiceFuture { + future: self.service.clone().oneshot(req), + } + } +} + +/// Future returned by [`TowerToHyperService`]. +#[derive(Debug)] +#[pin_project] +pub struct TowerToHyperServiceFuture +where + S: tower_service::Service, +{ + #[pin] + future: Oneshot, +} + +impl Future for TowerToHyperServiceFuture +where + S: tower_service::Service, + S::Error: Into + 'static, +{ + type Output = Result; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .future + .poll(cx) + .map_err(super::Error::from_source) + } +} + impl Router { pub(crate) fn new(server: Server, routes: Routes) -> Self { Self { server, routes } @@ -573,7 +753,7 @@ impl Router { /// Add a new service to this router. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -592,7 +772,7 @@ impl Router { #[allow(clippy::type_complexity)] pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -617,10 +797,12 @@ impl Router { /// [tokio]: https://docs.rs/tokio pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> where - L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L: Layer + Clone, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -648,9 +830,11 @@ impl Router { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -677,9 +861,11 @@ impl Router { IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -712,9 +898,11 @@ impl Router { IE: Into, F: Future, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -727,9 +915,11 @@ impl Router { pub fn into_service(self) -> L::Service where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -743,14 +933,15 @@ impl fmt::Debug for Server { } } +#[derive(Clone)] struct Svc { inner: S, trace_interceptor: Option, } -impl Service> for Svc +impl Service> for Svc where - S: Service, Response = Response>, + S: Service, Response = Response>, S::Error: Into, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, @@ -763,7 +954,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -806,7 +997,7 @@ where let _guard = this.span.enter(); let response: Response = ready!(this.inner.poll(cx)).map_err(Into::into)?; - let response = response.map(|body| body.map_err(Into::into).boxed_unsync()); + let response = response.map(|body| boxed(body.map_err(Into::into))); Poll::Ready(Ok(response)) } } @@ -817,6 +1008,7 @@ impl fmt::Debug for Svc { } } +#[derive(Clone)] struct MakeSvc { concurrency_limit: Option, timeout: Option, @@ -828,7 +1020,7 @@ struct MakeSvc { impl Service<&ServerIo> for MakeSvc where IO: Connected, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, ResBody: http_body::Body + Send + 'static, @@ -857,8 +1049,8 @@ where .service(svc); let svc = ServiceBuilder::new() - .layer(BoxService::layer()) - .map_request(move |mut request: Request| { + .layer(BoxCloneService::layer()) + .map_request(move |mut request: Request| { match &conn_info { tower::util::Either::A(inner) => { request.extensions_mut().insert(inner.clone()); @@ -889,3 +1081,29 @@ where future::ready(Ok(svc)) } } + +// From `futures-util` crate, borrowed since this is the only dependency tonic requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +struct Fuse { + #[pin] + inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 85636c4d4..c43782ba9 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,9 +1,9 @@ use crate::{ body::{boxed, BoxBody}, server::NamedService, + transport::BoxFuture, }; use http::{Request, Response}; -use hyper::Body; use pin_project::pin_project; use std::{ convert::Infallible, @@ -12,7 +12,6 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tower::ServiceExt; use tower_service::Service; /// A [`Service`] router. @@ -31,7 +30,7 @@ impl RoutesBuilder { /// Add a new service. pub fn add_service(&mut self, svc: S) -> &mut Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -53,7 +52,7 @@ impl Routes { /// Create a new routes with `svc` already added to it. pub fn new(svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -68,7 +67,7 @@ impl Routes { /// Add a new service. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -76,10 +75,10 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let svc = svc.map_response(|res| res.map(axum::body::boxed)); - self.router = self - .router - .route_service(&format!("/{}/*rest", S::NAME), svc); + self.router = self.router.route_service( + &format!("/{}/*rest", S::NAME), + AxumBodyService { service: svc }, + ); self } @@ -103,7 +102,7 @@ async fn unimplemented() -> impl axum::response::IntoResponse { (status, headers) } -impl Service> for Routes { +impl Service> for Routes { type Response = Response; type Error = crate::Error; type Future = RoutesFuture; @@ -113,13 +112,13 @@ impl Service> for Routes { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { RoutesFuture(self.router.call(req)) } } #[pin_project] -pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); +pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); impl fmt::Debug for RoutesFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -137,3 +136,33 @@ impl Future for RoutesFuture { } } } + +#[derive(Clone)] +struct AxumBodyService { + service: S, +} + +impl Service> for AxumBodyService +where + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.service.call(req.map(|body| boxed(body))); + Box::pin(async move { + fut.await + .map(|res| res.map(|body| axum::body::Body::new(body))) + }) + } +} From 2427bbdc022bb2e71d3f7f332bc0a37256996514 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:05:19 +0000 Subject: [PATCH 18/25] Convert service connector for hyper-1.0 Hyper >= 1 no longer includes automatic http2/http1 combined connections, and so we must swtich to the `http2::Builder` type (this is okay, we set http2_only(true) anyhow). As well, hyper >= 1 is generic over executors and does not directly depend on tokio. Since http2 connections can be multiplexed, they require some additional background task to handle sending and receiving requests. Additionally, these background tasks do not natively implement `tower::Service` since hyper >= 1 does not depend on `tower`. Therefore, we re-implement the `SendRequest` task as a tower::Service, so that it can be used within `Connection`, which expects to operate on a tower::Service to serve connections. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- tonic/src/transport/service/connection.rs | 113 +++++++++++++++++++--- tonic/src/transport/service/executor.rs | 9 +- 2 files changed, 103 insertions(+), 19 deletions(-) diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 8e1f52c5f..a31c9868b 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,13 +1,12 @@ +use super::SharedExec; use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{ body::{boxed, BoxBody}, transport::{BoxFuture, Endpoint}, }; use http::Uri; -use hyper::client::conn::Builder; -use hyper::client::connect::Connection as HyperConnection; -use hyper::client::service::Connect as HyperConnect; use hyper::rt; +use hyper::{client::conn::http2::Builder, rt::Executor}; use std::{ fmt, task::{Context, Poll}, @@ -36,24 +35,22 @@ impl Connection { C::Future: Unpin + Send, C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - let mut settings = Builder::new() - .http2_initial_stream_window_size(endpoint.init_stream_window_size) - .http2_initial_connection_window_size(endpoint.init_connection_window_size) - .http2_only(true) - .http2_keep_alive_interval(endpoint.http2_keep_alive_interval) - .executor(endpoint.executor.clone()) + let mut settings: Builder = Builder::new(endpoint.executor.clone()) + .initial_stream_window_size(endpoint.init_stream_window_size) + .initial_connection_window_size(endpoint.init_connection_window_size) + .keep_alive_interval(endpoint.http2_keep_alive_interval) .clone(); if let Some(val) = endpoint.http2_keep_alive_timeout { - settings.http2_keep_alive_timeout(val); + settings.keep_alive_timeout(val); } if let Some(val) = endpoint.http2_keep_alive_while_idle { - settings.http2_keep_alive_while_idle(val); + settings.keep_alive_while_idle(val); } if let Some(val) = endpoint.http2_adaptive_window { - settings.http2_adaptive_window(val); + settings.adaptive_window(val); } let stack = ServiceBuilder::new() @@ -68,13 +65,13 @@ impl Connection { .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); - let connector = HyperConnect::new(connector, settings); - let conn = Reconnect::new(connector, endpoint.uri.clone(), is_lazy); + let make_service = + MakeSendRequestService::new(connector, endpoint.executor.clone(), settings); - let inner = stack.layer(conn); + let conn = Reconnect::new(make_service, endpoint.uri.clone(), is_lazy); Self { - inner: BoxService::new(inner), + inner: BoxService::new(stack.layer(conn)), } } @@ -126,3 +123,87 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +struct SendRequest { + inner: hyper::client::conn::http2::SendRequest, +} + +impl From> for SendRequest { + fn from(inner: hyper::client::conn::http2::SendRequest) -> Self { + Self { inner } + } +} + +impl tower::Service> for SendRequest { + type Response = http::Response; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.inner.send_request(req); + + Box::pin(async move { + fut.await + .map_err(Into::into) + .map(|res| res.map(|body| boxed(body))) + }) + } +} + +struct MakeSendRequestService { + connector: C, + executor: super::SharedExec, + settings: Builder, +} + +impl MakeSendRequestService { + fn new(connector: C, executor: SharedExec, settings: Builder) -> Self { + Self { + connector, + executor, + settings, + } + } +} + +impl tower::Service for MakeSendRequestService +where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, +{ + type Response = SendRequest; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.connector.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let fut = self.connector.call(req); + let builder = self.settings.clone(); + let executor = self.executor.clone(); + + Box::pin(async move { + let io = fut.await.map_err(Into::into)?; + let (send_request, conn) = builder.handshake(io).await?; + + Executor::>::execute( + &executor, + Box::pin(async move { + if let Err(e) = conn.await { + tracing::debug!("connection task error: {:?}", e); + } + }) as _, + ); + + Ok(SendRequest::from(send_request)) + }) + } +} diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/service/executor.rs index de3cfbe6e..7b699c307 100644 --- a/tonic/src/transport/service/executor.rs +++ b/tonic/src/transport/service/executor.rs @@ -36,8 +36,11 @@ impl SharedExec { } } -impl Executor> for SharedExec { - fn execute(&self, fut: BoxFuture<'static, ()>) { - self.inner.execute(fut) +impl Executor for SharedExec +where + F: Future + Send + 'static, +{ + fn execute(&self, fut: F) { + self.inner.execute(Box::pin(fut)) } } From c1cd975807143607f90f11c6411a1a362e383faa Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:07:21 +0000 Subject: [PATCH 19/25] Convert hyper::Client to hyper_util::legacy::Client `hyper::Client` has been moved to `hyper_util::legacy::Client` in version 1. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang --- examples/src/grpc-web/client.rs | 2 +- examples/src/h2c/client.rs | 16 ++++++++-------- examples/src/tls_rustls/client.rs | 5 +++-- tonic-web/tests/integration/tests/grpc_web.rs | 4 ++-- tonic/src/transport/channel/endpoint.rs | 11 ++++++----- tonic/src/transport/channel/mod.rs | 4 ++-- tonic/src/transport/service/discover.rs | 3 ++- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index fd20a788b..fa64dd506 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -9,7 +9,7 @@ pub mod hello_world { #[tokio::main] async fn main() -> Result<(), Box> { // Must use hyper directly... - let client = hyper::Client::builder().build_http(); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build_http(); let svc = tower::ServiceBuilder::new() .layer(GrpcWebClientLayer::new()) diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 2f9f90a79..b162fcc08 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -1,7 +1,7 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use http::Uri; -use hyper::Client; +use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; pub mod hello_world { @@ -12,7 +12,6 @@ pub mod hello_world { async fn main() -> Result<(), Box> { let origin = Uri::from_static("http://[::1]:50051"); let h2c_client = h2c::H2cChannel { - client: Client::new(), client: Client::builder(TokioExecutor::new()).build_http(), }; @@ -35,10 +34,11 @@ mod h2c { task::{Context, Poll}, }; - use hyper::{client::HttpConnector, Client}; use hyper::body::Incoming; use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, rt::TokioExecutor, + }; use tonic::body::{empty_body, BoxBody}; use tower::Service; @@ -77,11 +77,11 @@ mod h2c { let upgraded_io = hyper::upgrade::on(res).await.unwrap(); // In an ideal world you would somehow cache this connection - let (mut h2_client, conn) = hyper::client::conn::Builder::new() - .http2_only(true) - .handshake(upgraded_io) - .await - .unwrap(); + let (mut h2_client, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(upgraded_io) + .await + .unwrap(); tokio::spawn(conn); h2_client.send_request(request).await diff --git a/examples/src/tls_rustls/client.rs b/examples/src/tls_rustls/client.rs index 23d6e8130..f03f70051 100644 --- a/examples/src/tls_rustls/client.rs +++ b/examples/src/tls_rustls/client.rs @@ -5,7 +5,8 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::{client::HttpConnector, Uri}; +use hyper::Uri; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use pb::{echo_client::EchoClient, EchoRequest}; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; @@ -47,7 +48,7 @@ async fn main() -> Result<(), Box> { .map_request(|_| Uri::from_static("https://[::1]:50051")) .service(http); - let client = hyper::Client::builder().build(connector); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector); // Using `with_origin` will let the codegenerated client set the `scheme` and // `authority` from the porvided `Uri`. diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index b46d98d45..2c57f2680 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -6,6 +6,7 @@ use http_body_util::{BodyExt as _, Full}; use hyper::body::Incoming; use hyper::http::{header, StatusCode}; use hyper::{Method, Request, Uri}; +use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use prost::Message; use tokio::net::TcpListener; @@ -15,12 +16,12 @@ use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; use integration::Svc; +use tonic::Status; use tonic_web::GrpcWebLayer; #[tokio::test] async fn binary_request() { let server_url = spawn().await; - let client = Client::new(); let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web", "grpc-web"); @@ -44,7 +45,6 @@ async fn binary_request() { #[tokio::test] async fn text_request() { let server_url = spawn().await; - let client = Client::new(); let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 584c56f8c..6014960a8 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -8,8 +8,9 @@ use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use hyper::rt; +use hyper_util::client::legacy::connect::HttpConnector; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; -use tower::make::MakeConnection; +use tower_service::Service; /// Channel builder. /// @@ -333,7 +334,7 @@ impl Endpoint { /// Create a channel from this config. pub async fn connect(&self) -> Result { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -349,7 +350,7 @@ impl Endpoint { /// The channel returned by this method does not attempt to connect to the endpoint until first /// use. pub fn connect_lazy(&self) -> Channel { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -369,7 +370,7 @@ impl Endpoint { /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied. pub async fn connect_with_connector(&self, connector: C) -> Result where - C: MakeConnection + Send + 'static, + C: Service + Send + 'static, C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, @@ -394,7 +395,7 @@ impl Endpoint { /// uses a Unix socket transport. pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where - C: MakeConnection + Send + 'static, + C: Service + Send + 'static, C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 3e5869bcb..0983725f8 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -17,7 +17,7 @@ use http::{ uri::{InvalidUri, Uri}, Request, Response, }; -use hyper::client::connect::Connection as HyperConnection; +use hyper_util::client::legacy::connect::Connection as HyperConnection; use std::{ fmt, future::Future, @@ -42,7 +42,7 @@ const DEFAULT_BUFFER_SIZE: usize = 1024; /// A default batteries included `transport` channel. /// -/// This provides a fully featured http2 gRPC client based on [`hyper::Client`] +/// This provides a fully featured http2 gRPC client based on `hyper` /// and `tower` services. /// /// # Multiplexing requests diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 2d23ca74c..b9356110e 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -1,6 +1,7 @@ use super::connection::Connection; use crate::transport::Endpoint; +use hyper_util::client::legacy::connect::HttpConnector; use std::{ hash::Hash, pin::Pin, @@ -32,7 +33,7 @@ impl Stream for DynamicServiceStream { Poll::Pending | Poll::Ready(None) => Poll::Pending, Poll::Ready(Some(change)) => match change { Change::Insert(k, endpoint) => { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.set_nodelay(endpoint.tcp_nodelay); http.set_keepalive(endpoint.tcp_keepalive); http.set_connect_timeout(endpoint.connect_timeout); From c3c6ee7deff98a418ce58f8d1846728a8a61e22f Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 18:54:31 +0000 Subject: [PATCH 20/25] Identify and propogate connect errors hyper::Error no longer provides information about Connect errors, especially since hyper_util now contains the connection implementation, it does not provide a separate error type. Instead, we create an internal Error type which is used in our own connectors, and then checked when figuring out what the gRPC status should be. --- tonic/src/status.rs | 13 +++--- tonic/src/transport/mod.rs | 2 + tonic/src/transport/service/connector.rs | 58 ++++++++++++++++-------- tonic/src/transport/service/mod.rs | 1 + 4 files changed, 49 insertions(+), 25 deletions(-) diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 4ce3abde6..77c85b47d 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -412,13 +412,7 @@ impl Status { // > status. Note that the frequency of PINGs is highly dependent on the network // > environment, implementations are free to adjust PING frequency based on network and // > application requirements, which is why it's mapped to unavailable here. - // - // Likewise, if we are unable to connect to the server, map this to UNAVAILABLE. This is - // consistent with the behavior of a C++ gRPC client when the server is not running, and - // matches the spec of: - // > The service is currently unavailable. This is most likely a transient condition that - // > can be corrected if retried with a backoff. - if err.is_timeout() || err.is_connect() { + if err.is_timeout() { return Some(Status::unavailable(err.to_string())); } @@ -623,6 +617,11 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { return Some(Status::cancelled(timeout.to_string())); } + #[cfg(feature = "transport")] + if let Some(connect) = err.downcast_ref::() { + return Some(Status::unavailable(connect.to_string())); + } + #[cfg(feature = "transport")] if let Some(hyper) = err .downcast_ref::() diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 0301db3fd..357ab2bf4 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -107,6 +107,8 @@ pub use self::error::Error; pub use self::server::Server; #[doc(inline)] pub use self::service::grpc_timeout::TimeoutExpired; +pub(crate) use self::service::ConnectError; + #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 8219fe8d9..978441d75 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -3,7 +3,6 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; -#[cfg(feature = "tls")] use std::fmt; use std::task::{Context, Poll}; @@ -13,6 +12,23 @@ use hyper::rt; use hyper_util::rt::TokioIo; use tower_service::Service; +/// Wrapper type to indicate that an error occurs during the connection +/// process, so that the appropriate gRPC Status can be inferred. +#[derive(Debug)] +pub(crate) struct ConnectError(pub(crate) crate::Error); + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl std::error::Error for ConnectError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.0.as_ref()) + } +} + pub(crate) struct Connector { inner: C, #[cfg(feature = "tls")] @@ -61,11 +77,13 @@ where crate::Error: From + Send + 'static, { type Response = BoxedIo; - type Error = crate::Error; + type Error = ConnectError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Into::into) + self.inner + .poll_ready(cx) + .map_err(|err| ConnectError(From::from(err))) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -80,23 +98,27 @@ where let connect = self.inner.call(uri); Box::pin(async move { - let io = connect.await?; - - #[cfg(feature = "tls")] - { - if let Some(tls) = tls { - return if is_https { - let io = tls.connect(TokioIo::new(io)).await?; - Ok(io) - } else { - Ok(BoxedIo::new(io)) - }; - } else if is_https { - return Err(HttpsUriWithoutTlsSupport(()).into()); + async { + let io = connect.await?; + + #[cfg(feature = "tls")] + { + if let Some(tls) = tls { + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) + } else { + Ok(BoxedIo::new(io)) + }; + } else if is_https { + return Err(HttpsUriWithoutTlsSupport(()).into()); + } } - } - Ok(BoxedIo::new(io)) + Ok::<_, crate::Error>(BoxedIo::new(io)) + } + .await + .map_err(|err| ConnectError(From::from(err))) }) } } diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 69d850f10..2b2a84070 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -13,6 +13,7 @@ mod user_agent; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; +pub(crate) use self::connector::ConnectError; pub(crate) use self::connector::Connector; pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::executor::SharedExec; From e349b131e3403dc8f6df852602e90d96ee18a9af Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:10:24 +0000 Subject: [PATCH 21/25] Remove hyper::server::conn::AddrStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hyper >= 1 has deprecated all of `hyper::server`, including `AddrStream` Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang Replace hyper::server::Accept hyper::server is deprectaed. Instead, we implement our own TCP-incoming based on the now removed hyper::server::Accept. In order to set `TCP_KEEPALIVE` we require the socket2 crate, since this option is not exposed in the standard library’s API. The implementaiton is inspired by that of hyper v0.14 --- tonic/Cargo.toml | 3 +- tonic/src/transport/server/conn.rs | 12 ----- tonic/src/transport/server/incoming.rs | 66 +++++++++++++++++++------- 3 files changed, 50 insertions(+), 31 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index da4482291..b795cc5f9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,8 +37,9 @@ transport = [ "dep:axum", "channel", "dep:h2", - "dep:tokio", "tokio?/net", "tokio?/time", "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", + "dep:socket2", + "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", "dep:tower", ] channel = [] diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 907cf4965..37bcc561b 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,4 +1,3 @@ -use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; @@ -86,17 +85,6 @@ impl TcpConnectInfo { } } -impl Connected for AddrStream { - type ConnectInfo = TcpConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - TcpConnectInfo { - local_addr: Some(self.local_addr()), - remote_addr: Some(self.remote_addr()), - } - } -} - impl Connected for TcpStream { type ConnectInfo = TcpConnectInfo; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index ede62a32d..7f5f76c25 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,20 +1,18 @@ use super::{Connected, Server}; use crate::transport::service::ServerIo; -use hyper::server::{ - accept::Accept, - conn::{AddrIncoming, AddrStream}, -}; use std::{ - net::SocketAddr, + net::{SocketAddr, TcpListener as StdTcpListener}, pin::{pin, Pin}, - task::{Context, Poll}, + task::{ready, Context, Poll}, time::Duration, }; use tokio::{ io::{AsyncRead, AsyncWrite}, - net::TcpListener, + net::{TcpListener, TcpStream}, }; +use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::{Stream, StreamExt}; +use tracing::warn; #[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( @@ -127,7 +125,9 @@ enum SelectOutput { /// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address. #[derive(Debug)] pub struct TcpIncoming { - inner: AddrIncoming, + inner: TcpListenerStream, + nodelay: bool, + keepalive: Option, } impl TcpIncoming { @@ -167,10 +167,15 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::bind(&addr)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + let std_listener = StdTcpListener::bind(addr)?; + std_listener.set_nonblocking(true)?; + + let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?); + Ok(Self { + inner, + nodelay, + keepalive, + }) } /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. @@ -179,18 +184,43 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::from_listener(listener)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + Ok(Self { + inner: TcpListenerStream::new(listener), + nodelay, + keepalive, + }) } } impl Stream for TcpIncoming { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_accept(cx) + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(stream)) => { + set_accepted_socket_options(&stream, self.nodelay, self.keepalive); + Some(Ok(stream)).into() + } + other => Poll::Ready(other), + } + } +} + +// Consistent with hyper-0.14, this function does not return an error. +fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option) { + if nodelay { + if let Err(e) = stream.set_nodelay(true) { + warn!("error trying to set TCP nodelay: {}", e); + } + } + + if let Some(timeout) = keepalive { + let sock_ref = socket2::SockRef::from(&stream); + let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); + + if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { + warn!("error trying to set TCP keepalive: {}", e); + } } } From b81e0fc3cd465e278cf23afda6ca52e98e0dc495 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 03:13:39 +0000 Subject: [PATCH 22/25] [examples] In h2c, replace hyper::Server with an accept loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hyper::Server is deprecated, with no current common replacement. Instead of implementing (or using tonic’s new) full server in here, we write a simple accept loop, which is sufficient to demonstrate the functionality of h2c. --- examples/src/h2c/server.rs | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index b1d4c0a8d..da5c3425c 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -1,9 +1,14 @@ +use std::net::SocketAddr; + use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use hyper_util::service::TowerToHyperService; +use tokio::net::TcpListener; +// use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; use hello_world::{HelloReply, HelloRequest}; -use tower::make::Shared; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -29,21 +34,36 @@ impl Greeter for MyGreeter { #[tokio::main] async fn main() -> Result<(), Box> { - let addr = "[::1]:50051".parse().unwrap(); + let addr: SocketAddr = "[::1]:50051".parse().unwrap(); let greeter = MyGreeter::default(); println!("GreeterServer listening on {}", addr); + let incoming = TcpListener::bind(addr).await?; let svc = Server::builder() .add_service(GreeterServer::new(greeter)) .into_router(); let h2c = h2c::H2c { s: svc }; - let server = hyper::Server::bind(&addr).serve(Shared::new(h2c)); - server.await.unwrap(); - - Ok(()) + loop { + match incoming.accept().await { + Ok((io, _)) => { + let router = h2c.clone(); + tokio::spawn(async move { + let builder = Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades( + TokioIo::new(io), + TowerToHyperService::new(router), + ); + let _ = conn.await; + }); + } + Err(e) => { + eprintln!("Error accepting connection: {}", e); + } + } + } } mod h2c { From 3d44ca858cc944eb80f9a150b7bd4d8302d99a4a Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Mon, 27 May 2024 04:22:29 +0000 Subject: [PATCH 23/25] Upgrade tls dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hyper-rustls requires version 0.27.0 to support hyper >= 1, bringing a few other tls bumps along. Importantly, we add the “ring” and “tls12” features to use ring as the crypto backend, consistent with previous versions of tonic. A future version of tonic might support selecting backends via features. Co-authored-by: Ivan Krivosheev --- examples/src/tls_rustls/client.rs | 5 ++-- examples/src/tls_rustls/server.rs | 39 ++++++++++++++++++------------ tonic/src/transport/service/tls.rs | 3 ++- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/examples/src/tls_rustls/client.rs b/examples/src/tls_rustls/client.rs index f03f70051..4c39a2c46 100644 --- a/examples/src/tls_rustls/client.rs +++ b/examples/src/tls_rustls/client.rs @@ -18,11 +18,10 @@ async fn main() -> Result<(), Box> { let mut roots = RootCertStore::empty(); let mut buf = std::io::BufReader::new(&fd); - let certs = rustls_pemfile::certs(&mut buf)?; - roots.add_parsable_certificates(&certs); + let certs = rustls_pemfile::certs(&mut buf).collect::, _>>()?; + roots.add_parsable_certificates(certs.into_iter()); let tls = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth(); diff --git a/examples/src/tls_rustls/server.rs b/examples/src/tls_rustls/server.rs index 82f009344..5630edfa1 100644 --- a/examples/src/tls_rustls/server.rs +++ b/examples/src/tls_rustls/server.rs @@ -2,45 +2,51 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::server::conn::Http; +use hyper::server::conn::http2::Builder; +use hyper_util::rt::{TokioExecutor, TokioIo}; use pb::{EchoRequest, EchoResponse}; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{ - rustls::{Certificate, PrivateKey, ServerConfig}, + rustls::{ + pki_types::{CertificateDer, PrivatePkcs8KeyDer}, + ServerConfig, + }, TlsAcceptor, }; +use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use tower_http::ServiceBuilderExt; #[tokio::main] async fn main() -> Result<(), Box> { let data_dir = std::path::PathBuf::from_iter([std::env!("CARGO_MANIFEST_DIR"), "data"]); - let certs = { + let certs: Vec> = { let fd = std::fs::File::open(data_dir.join("tls/server.pem"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::certs(&mut buf)? + rustls_pemfile::certs(&mut buf) .into_iter() - .map(Certificate) - .collect() + .map(|res| res.map(|cert| cert.to_owned())) + .collect::, _>>()? }; - let key = { + let key: PrivatePkcs8KeyDer<'static> = { let fd = std::fs::File::open(data_dir.join("tls/server.key"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::pkcs8_private_keys(&mut buf)? + let key = rustls_pemfile::pkcs8_private_keys(&mut buf) .into_iter() - .map(PrivateKey) .next() - .unwrap() + .unwrap()? + .clone_key(); + + key // let key = std::fs::read(data_dir.join("tls/server.key"))?; // PrivateKey(key) }; let mut tls = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_single_cert(certs, key.into())?; tls.alpn_protocols = vec![b"h2".to_vec()]; let server = EchoServer::default(); @@ -49,8 +55,7 @@ async fn main() -> Result<(), Box> { .add_service(pb::echo_server::EchoServer::new(server)) .into_service(); - let mut http = Http::new(); - http.http2_only(true); + let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("[::1]:50051").await?; let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); @@ -86,7 +91,9 @@ async fn main() -> Result<(), Box> { .add_extension(Arc::new(ConnInfo { addr, certificates })) .service(svc); - http.serve_connection(conn, svc).await.unwrap(); + http.serve_connection(TokioIo::new(conn), TowerToHyperService::new(svc)) + .await + .unwrap(); }); } } @@ -94,7 +101,7 @@ async fn main() -> Result<(), Box> { #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, - certificates: Vec, + certificates: Vec>, } type EchoResult = Result, Status>; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 2a6394a4f..2ce9dc5da 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -18,6 +18,7 @@ use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; +use hyper_util::rt::TokioIo; /// h2 alpn in plain format for rustls. const ALPN_H2: &[u8] = b"h2"; @@ -88,7 +89,7 @@ impl TlsConnector { if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { return Err(TlsError::H2NotNegotiated.into()); } - Ok(BoxedIo::new(io)) + Ok(BoxedIo::new(TokioIo::new(io))) } } From fd2d6127deadc668ecdbcf45aec231be8ba02174 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 31 May 2024 04:52:20 +0000 Subject: [PATCH 24/25] Combine trailers when streaming decode body We aren't sure if multiple trailers should even be legal, but if we get multiple trailers in an HTTP body stream, we'll combine them all, to preserve their data. Alternatively we'd have to pick the first or last trailers, and that might lose information. --- tonic/src/codec/decode.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index dea83d931..a38c4a834 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -262,7 +262,15 @@ impl StreamingInner { Ok(Some(())) } frame if frame.is_trailers() => { - self.trailers = Some(frame.into_trailers().unwrap()); + match &mut self.trailers { + Some(trailers) => { + trailers.extend(frame.into_trailers().unwrap()); + } + None => { + self.trailers = Some(frame.into_trailers().unwrap()); + } + } + Ok(None) } frame => panic!("unexpected frame: {:?}", frame), From 3fb8a91c98d204b487223475603c575e6ad9bb55 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 31 May 2024 04:53:19 +0000 Subject: [PATCH 25/25] Tweak imports in transport example Example used `empty_body()`, which is now fully qualified as `tonic::body::empty_body()` to make clear that this is a tonic helper method for creating an empty BoxBody. --- tonic/src/transport/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 357ab2bf4..cce5d7b6c 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -22,7 +22,6 @@ //! # use tonic::transport::{Channel, Certificate, ClientTlsConfig}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; -//! # use tonic::body::empty_body; //! # use tonic::client::GrpcService;; //! # use http::Request; //! # #[cfg(feature = "rustls")] @@ -39,7 +38,7 @@ //! .connect() //! .await?; //! -//! channel.call(Request::new(empty_body())).await?; +//! channel.call(Request::new(tonic::body::empty_body())).await?; //! # Ok(()) //! # } //! ```