diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 57d05d3e3..e04868826 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -271,21 +271,21 @@ routeguide = ["dep:async-stream", "tokio-stream", "dep:rand", "dep:serde", "dep: reflection = ["dep:tonic-reflection"] autoreload = ["tokio-stream/net", "dep:listenfd"] health = ["dep:tonic-health"] -grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber", "dep:tower"] +grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:hyper-util", "dep:tracing-subscriber", "dep:tower"] tracing = ["dep:tracing", "dep:tracing-subscriber"] -uds = ["tokio-stream/net", "dep:tower", "dep:hyper"] +uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["tokio-stream", "dep:h2"] -mock = ["tokio-stream", "dep:tower"] -tower = ["dep:hyper", "dep:tower", "dep:http"] +mock = ["tokio-stream", "dep:tower", "dep:hyper-util"] +tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] -tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] +tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] dynamic-load-balance = ["dep:tower"] timeout = ["tokio/time", "dep:tower"] tls-client-auth = ["tonic/tls"] types = ["dep:tonic-types"] -h2c = ["dep:hyper", "dep:tower", "dep:http"] +h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] cancellation = ["dep:tokio-util"] full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "tower", "json-codec", "compression", "tls", "tls-rustls", "dynamic-load-balance", "timeout", "tls-client-auth", "types", "cancellation", "h2c"] @@ -311,16 +311,18 @@ serde_json = { version = "1.0", optional = true } tracing = { version = "0.1.16", optional = true } tracing-subscriber = { version = "0.3", features = ["tracing-log", "fmt"], optional = true } prost-types = { version = "0.12", optional = true } -http = { version = "0.2", optional = true } -http-body = { version = "0.4.2", optional = true } -hyper = { version = "0.14", optional = true } +http = { version = "1", optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +hyper = { version = "1", optional = true } +hyper-util = { version = "0.1", optional = true } listenfd = { version = "1.0", optional = true } bytes = { version = "1", optional = true } h2 = { version = "0.3", optional = true } -tokio-rustls = { version = "0.24.0", optional = true } -hyper-rustls = { version = "0.24.0", features = ["http2"], optional = true } -rustls-pemfile = { version = "1", optional = true } -tower-http = { version = "0.4", optional = true } +tokio-rustls = { version = "0.26", optional = true, features = ["ring", "tls12"], default-features = false } +hyper-rustls = { version = "0.27.0", features = ["http2", "ring", "tls12"], optional = true, default-features = false } +rustls-pemfile = { version = "2.0.0", optional = true } +tower-http = { version = "0.5", optional = true } [build-dependencies] tonic-build = { path = "../tonic-build", features = ["prost"] } diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index a16ac674a..fa64dd506 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -1,4 +1,5 @@ use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioExecutor; use tonic_web::GrpcWebClientLayer; pub mod hello_world { @@ -8,7 +9,7 @@ pub mod hello_world { #[tokio::main] async fn main() -> Result<(), Box> { // Must use hyper directly... - let client = hyper::Client::builder().build_http(); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build_http(); let svc = tower::ServiceBuilder::new() .layer(GrpcWebClientLayer::new()) diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 31076b1ac..b162fcc08 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -1,7 +1,8 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use http::Uri; -use hyper::Client; +use hyper_util::client::legacy::Client; +use hyper_util::rt::TokioExecutor; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -11,7 +12,7 @@ pub mod hello_world { async fn main() -> Result<(), Box> { let origin = Uri::from_static("http://[::1]:50051"); let h2c_client = h2c::H2cChannel { - client: Client::new(), + client: Client::builder(TokioExecutor::new()).build_http(), }; let mut client = GreeterClient::with_origin(h2c_client, origin); @@ -33,16 +34,20 @@ mod h2c { task::{Context, Poll}, }; - use hyper::{client::HttpConnector, Client}; - use tonic::body::BoxBody; + use hyper::body::Incoming; + use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, + }; + use tonic::body::{empty_body, BoxBody}; use tower::Service; pub struct H2cChannel { - pub client: Client, + pub client: Client, } impl Service> for H2cChannel { - type Response = http::Response; + type Response = http::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -60,7 +65,7 @@ mod h2c { let h2c_req = hyper::Request::builder() .uri(origin) .header(http::header::UPGRADE, "h2c") - .body(hyper::Body::empty()) + .body(empty_body()) .unwrap(); let res = client.request(h2c_req).await.unwrap(); @@ -72,11 +77,11 @@ mod h2c { let upgraded_io = hyper::upgrade::on(res).await.unwrap(); // In an ideal world you would somehow cache this connection - let (mut h2_client, conn) = hyper::client::conn::Builder::new() - .http2_only(true) - .handshake(upgraded_io) - .await - .unwrap(); + let (mut h2_client, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(upgraded_io) + .await + .unwrap(); tokio::spawn(conn); h2_client.send_request(request).await diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 92d08a417..da5c3425c 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -1,8 +1,14 @@ +use std::net::SocketAddr; + +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use hyper_util::service::TowerToHyperService; +use tokio::net::TcpListener; +// use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; use hello_world::{HelloReply, HelloRequest}; -use tower::make::Shared; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -28,28 +34,45 @@ impl Greeter for MyGreeter { #[tokio::main] async fn main() -> Result<(), Box> { - let addr = "[::1]:50051".parse().unwrap(); + let addr: SocketAddr = "[::1]:50051".parse().unwrap(); let greeter = MyGreeter::default(); println!("GreeterServer listening on {}", addr); + let incoming = TcpListener::bind(addr).await?; let svc = Server::builder() .add_service(GreeterServer::new(greeter)) .into_router(); let h2c = h2c::H2c { s: svc }; - let server = hyper::Server::bind(&addr).serve(Shared::new(h2c)); - server.await.unwrap(); - - Ok(()) + loop { + match incoming.accept().await { + Ok((io, _)) => { + let router = h2c.clone(); + tokio::spawn(async move { + let builder = Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades( + TokioIo::new(io), + TowerToHyperService::new(router), + ); + let _ = conn.await; + }); + } + Err(e) => { + eprintln!("Error accepting connection: {}", e); + } + } + } } mod h2c { use std::pin::Pin; use http::{Request, Response}; - use hyper::Body; + use hyper::body::Incoming; + use hyper_util::{rt::TokioExecutor, service::TowerToHyperService}; + use tonic::{body::empty_body, transport::AxumBoxBody}; use tower::Service; #[derive(Clone)] @@ -59,17 +82,14 @@ mod h2c { type BoxError = Box; - impl Service> for H2c + impl Service> for H2c where - S: Service, Response = Response> - + Clone - + Send - + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Sync + Send + 'static, S::Response: Send + 'static, { - type Response = hyper::Response; + type Response = hyper::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -81,20 +101,19 @@ mod h2c { std::task::Poll::Ready(Ok(())) } - fn call(&mut self, mut req: hyper::Request) -> Self::Future { + fn call(&mut self, mut req: hyper::Request) -> Self::Future { let svc = self.s.clone(); Box::pin(async move { tokio::spawn(async move { let upgraded_io = hyper::upgrade::on(&mut req).await.unwrap(); - hyper::server::conn::Http::new() - .http2_only(true) - .serve_connection(upgraded_io, svc) + hyper::server::conn::http2::Builder::new(TokioExecutor::new()) + .serve_connection(upgraded_io, TowerToHyperService::new(svc)) .await .unwrap(); }); - let mut res = hyper::Response::new(hyper::Body::empty()); + let mut res = hyper::Response::new(empty_body()); *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS; res.headers_mut().insert( hyper::header::UPGRADE, diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs index 263348a6d..fd0cf462f 100644 --- a/examples/src/interceptor/server.rs +++ b/examples/src/interceptor/server.rs @@ -57,6 +57,7 @@ fn intercept(mut req: Request<()>) -> Result, Status> { Ok(req) } +#[derive(Clone)] struct MyExtension { some_piece_of_data: String, } diff --git a/examples/src/mock/mock.rs b/examples/src/mock/mock.rs index 0d3754921..6c26a6735 100644 --- a/examples/src/mock/mock.rs +++ b/examples/src/mock/mock.rs @@ -1,3 +1,4 @@ +use hyper_util::rt::TokioIo; use tonic::{ transport::{Endpoint, Server, Uri}, Request, Response, Status, @@ -36,7 +37,7 @@ async fn main() -> Result<(), Box> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/examples/src/tls_rustls/client.rs b/examples/src/tls_rustls/client.rs index 23d6e8130..4c39a2c46 100644 --- a/examples/src/tls_rustls/client.rs +++ b/examples/src/tls_rustls/client.rs @@ -5,7 +5,8 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::{client::HttpConnector, Uri}; +use hyper::Uri; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use pb::{echo_client::EchoClient, EchoRequest}; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; @@ -17,11 +18,10 @@ async fn main() -> Result<(), Box> { let mut roots = RootCertStore::empty(); let mut buf = std::io::BufReader::new(&fd); - let certs = rustls_pemfile::certs(&mut buf)?; - roots.add_parsable_certificates(&certs); + let certs = rustls_pemfile::certs(&mut buf).collect::, _>>()?; + roots.add_parsable_certificates(certs.into_iter()); let tls = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth(); @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box> { .map_request(|_| Uri::from_static("https://[::1]:50051")) .service(http); - let client = hyper::Client::builder().build(connector); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector); // Using `with_origin` will let the codegenerated client set the `scheme` and // `authority` from the porvided `Uri`. diff --git a/examples/src/tls_rustls/server.rs b/examples/src/tls_rustls/server.rs index 82f009344..5630edfa1 100644 --- a/examples/src/tls_rustls/server.rs +++ b/examples/src/tls_rustls/server.rs @@ -2,45 +2,51 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::server::conn::Http; +use hyper::server::conn::http2::Builder; +use hyper_util::rt::{TokioExecutor, TokioIo}; use pb::{EchoRequest, EchoResponse}; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{ - rustls::{Certificate, PrivateKey, ServerConfig}, + rustls::{ + pki_types::{CertificateDer, PrivatePkcs8KeyDer}, + ServerConfig, + }, TlsAcceptor, }; +use tonic::transport::server::TowerToHyperService; use tonic::{transport::Server, Request, Response, Status}; use tower_http::ServiceBuilderExt; #[tokio::main] async fn main() -> Result<(), Box> { let data_dir = std::path::PathBuf::from_iter([std::env!("CARGO_MANIFEST_DIR"), "data"]); - let certs = { + let certs: Vec> = { let fd = std::fs::File::open(data_dir.join("tls/server.pem"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::certs(&mut buf)? + rustls_pemfile::certs(&mut buf) .into_iter() - .map(Certificate) - .collect() + .map(|res| res.map(|cert| cert.to_owned())) + .collect::, _>>()? }; - let key = { + let key: PrivatePkcs8KeyDer<'static> = { let fd = std::fs::File::open(data_dir.join("tls/server.key"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::pkcs8_private_keys(&mut buf)? + let key = rustls_pemfile::pkcs8_private_keys(&mut buf) .into_iter() - .map(PrivateKey) .next() - .unwrap() + .unwrap()? + .clone_key(); + + key // let key = std::fs::read(data_dir.join("tls/server.key"))?; // PrivateKey(key) }; let mut tls = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_single_cert(certs, key.into())?; tls.alpn_protocols = vec![b"h2".to_vec()]; let server = EchoServer::default(); @@ -49,8 +55,7 @@ async fn main() -> Result<(), Box> { .add_service(pb::echo_server::EchoServer::new(server)) .into_service(); - let mut http = Http::new(); - http.http2_only(true); + let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("[::1]:50051").await?; let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); @@ -86,7 +91,9 @@ async fn main() -> Result<(), Box> { .add_extension(Arc::new(ConnInfo { addr, certificates })) .service(svc); - http.serve_connection(conn, svc).await.unwrap(); + http.serve_connection(TokioIo::new(conn), TowerToHyperService::new(svc)) + .await + .unwrap(); }); } } @@ -94,7 +101,7 @@ async fn main() -> Result<(), Box> { #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, - certificates: Vec, + certificates: Vec>, } type EchoResult = Result, Status>; diff --git a/examples/src/tower/client.rs b/examples/src/tower/client.rs index 0a33fffae..39fec5d47 100644 --- a/examples/src/tower/client.rs +++ b/examples/src/tower/client.rs @@ -44,7 +44,6 @@ mod service { use std::pin::Pin; use std::task::{Context, Poll}; use tonic::body::BoxBody; - use tonic::transport::Body; use tonic::transport::Channel; use tower::Service; @@ -59,7 +58,7 @@ mod service { } impl Service> for AuthSvc { - type Response = Response; + type Response = Response; type Error = Box; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; diff --git a/examples/src/tower/server.rs b/examples/src/tower/server.rs index cc85d62e5..b7066a1b6 100644 --- a/examples/src/tower/server.rs +++ b/examples/src/tower/server.rs @@ -1,4 +1,3 @@ -use hyper::Body; use std::{ pin::Pin, task::{Context, Poll}, @@ -84,9 +83,12 @@ struct MyMiddleware { type BoxFuture<'a, T> = Pin + Send + 'a>>; -impl Service> for MyMiddleware +impl Service> for MyMiddleware where - S: Service, Response = hyper::Response> + Clone + Send + 'static, + S: Service, Response = hyper::Response> + + Clone + + Send + + 'static, S::Future: Send + 'static, { type Response = S::Response; @@ -97,7 +99,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: hyper::Request) -> Self::Future { + fn call(&mut self, req: hyper::Request) -> Self::Future { // This is necessary because tonic internally uses `tower::buffer::Buffer`. // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 // for details on why this is necessary diff --git a/examples/src/uds/client.rs b/examples/src/uds/client.rs index e78531ac4..9a09e6981 100644 --- a/examples/src/uds/client.rs +++ b/examples/src/uds/client.rs @@ -5,6 +5,7 @@ pub mod hello_world { } use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixStream; use tonic::transport::{Endpoint, Uri}; @@ -16,12 +17,13 @@ async fn main() -> Result<(), Box> { // We will ignore this uri because uds do not use it // if your connector does use the uri it will be provided // as the request to the `MakeConnection`. + let channel = Endpoint::try_from("http://[::]:50051")? - .connect_with_connector(service_fn(|_: Uri| { + .connect_with_connector(service_fn(|_: Uri| async { let path = "/tmp/tonic/helloworld"; // Connect to a Uds socket - UnixStream::connect(path) + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path).await?)) })) .await?; diff --git a/interop/Cargo.toml b/interop/Cargo.toml index a58ef64cf..9a32b2a1d 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -19,9 +19,9 @@ async-stream = "0.3" strum = {version = "0.26", features = ["derive"]} pico-args = {version = "0.5", features = ["eq-separator"]} console = "0.15" -http = "0.2" -http-body = "0.4.2" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" prost = "0.12" tokio = {version = "1.0", features = ["rt-multi-thread", "time", "macros"]} tokio-stream = "0.1" diff --git a/interop/src/server.rs b/interop/src/server.rs index b32468866..38b1be65e 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -1,10 +1,10 @@ use crate::pb::{self, *}; use async_stream::try_stream; -use http::header::{HeaderMap, HeaderName, HeaderValue}; +use http::header::{HeaderName, HeaderValue}; use http_body::Body; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio_stream::StreamExt; use tonic::{body::BoxBody, server::NamedService, Code, Request, Response, Status}; @@ -180,9 +180,9 @@ impl EchoHeadersSvc { } } -impl Service> for EchoHeadersSvc +impl Service> for EchoHeadersSvc where - S: Service, Response = http::Response> + Send, + S: Service, Response = http::Response> + Send, S::Future: Send + 'static, { type Response = S::Response; @@ -193,7 +193,7 @@ where Ok(()).into() } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned(); let echo_trailer = req @@ -235,25 +235,19 @@ impl Body for MergeTrailers { type Data = B::Data; type Error = B::Error; - fn poll_data( - mut self: Pin<&mut Self>, + fn poll_frame( + self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(&mut self.inner).poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { - h.map(|mut headers| { - if let Some((key, value)) = &self.trailer { - headers.insert(key.clone(), value.clone()); + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + let mut frame = ready!(Pin::new(&mut this.inner).poll_frame(cx)?); + if let Some(frame) = frame.as_mut() { + if let Some(trailers) = frame.trailers_mut() { + if let Some((key, value)) = &this.trailer { + trailers.insert(key.clone(), value.clone()); } - - headers - }) - }) + } + } + Poll::Ready(frame.map(Ok)) } } diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 5bc87c829..cf4da321b 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -8,9 +8,11 @@ version = "0.1.0" [dependencies] bytes = "1" -http = "0.2" -http-body = "0.4" -hyper = "0.14.3" +http = "1" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" paste = "1.0.12" pin-project = "1.0" prost = "0.12" @@ -18,7 +20,7 @@ tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tokio-stream = "0.1" tonic = {path = "../../tonic", features = ["gzip", "zstd"]} tower = {version = "0.4", features = []} -tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]} +tower-http = {version = "0.5", features = ["map-response-body", "map-request-body"]} [build-dependencies] tonic-build = {path = "../../tonic-build" } diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 28fa5d96a..d7e250ce4 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -1,6 +1,8 @@ use super::*; -use bytes::Bytes; -use http_body::Body; +use bytes::{Buf, Bytes}; +use http_body::{Body, Frame}; +use http_body_util::BodyExt as _; +use hyper_util::rt::TokioIo; use pin_project::pin_project; use std::{ pin::Pin, @@ -11,6 +13,7 @@ use std::{ task::{ready, Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::body::BoxBody; use tonic::codec::CompressionEncoding; use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; @@ -46,29 +49,22 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); let counter: Arc = this.counter.clone(); - match ready!(this.inner.poll_data(cx)) { + match ready!(this.inner.poll_frame(cx)) { Some(Ok(chunk)) => { - println!("response body chunk size = {}", chunk.len()); - counter.fetch_add(chunk.len(), SeqCst); + println!("response body chunk size = {}", frame_data_length(&chunk)); + counter.fetch_add(frame_data_length(&chunk), SeqCst); Poll::Ready(Some(Ok(chunk))) } x => Poll::Ready(x), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -78,28 +74,61 @@ where } } +fn frame_data_length(frame: &http_body::Frame) -> usize { + if let Some(data) = frame.data_ref() { + data.len() + } else { + 0 + } +} + +#[pin_project] +struct ChannelBody { + #[pin] + rx: tokio::sync::mpsc::Receiver>, +} + +impl ChannelBody { + pub fn new() -> (tokio::sync::mpsc::Sender>, Self) { + let (tx, rx) = tokio::sync::mpsc::channel(32); + (tx, Self { rx }) + } +} + +impl Body for ChannelBody +where + T: Buf, +{ + type Data = T; + type Error = tonic::Status; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let frame = ready!(self.project().rx.poll_recv(cx)); + Poll::Ready(frame.map(Ok)) + } +} + #[allow(dead_code)] pub fn measure_request_body_size_layer( bytes_sent_counter: Arc, -) -> MapRequestBodyLayer hyper::Body + Clone> { - MapRequestBodyLayer::new(move |mut body: hyper::Body| { - let (mut tx, new_body) = hyper::Body::channel(); +) -> MapRequestBodyLayer BoxBody + Clone> { + MapRequestBodyLayer::new(move |mut body: BoxBody| { + let (tx, new_body) = ChannelBody::new(); let bytes_sent_counter = bytes_sent_counter.clone(); tokio::spawn(async move { - while let Some(chunk) = body.data().await { + while let Some(chunk) = body.frame().await { let chunk = chunk.unwrap(); - println!("request body chunk size = {}", chunk.len()); - bytes_sent_counter.fetch_add(chunk.len(), SeqCst); - tx.send_data(chunk).await.unwrap(); - } - - if let Some(trailers) = body.trailers().await.unwrap() { - tx.send_trailers(trailers).await.unwrap(); + println!("request body chunk size = {}", frame_data_length(&chunk)); + bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst); + tx.send(chunk).await.unwrap(); } }); - new_body + new_body.boxed_unsync() }) } @@ -110,7 +139,7 @@ pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector(service_fn(move |_: Uri| { - let client = client.take().unwrap(); + let client = TokioIo::new(client.take().unwrap()); async move { Ok::<_, std::io::Error>(client) } })) .await diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 222d1919c..cfeebf725 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -17,12 +17,13 @@ tracing-subscriber = {version = "0.3"} [dev-dependencies] async-stream = "0.3" -http = "0.2" -http-body = "0.4" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" +hyper-util = "0.1" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} -tower-http = { version = "0.4", features = ["set-header", "trace"] } +tower-http = { version = "0.5", features = ["set-header", "trace"] } tower-service = "0.3" tracing = "0.1" diff --git a/tests/integration_tests/tests/complex_tower_middleware.rs b/tests/integration_tests/tests/complex_tower_middleware.rs index 5d7690be3..b1b669426 100644 --- a/tests/integration_tests/tests/complex_tower_middleware.rs +++ b/tests/integration_tests/tests/complex_tower_middleware.rs @@ -97,17 +97,10 @@ where type Data = B::Data; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - unimplemented!() - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { + ) -> Poll, Self::Error>>> { unimplemented!() } } diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs index 94fac8221..e87bb858f 100644 --- a/tests/integration_tests/tests/connect_info.rs +++ b/tests/integration_tests/tests/connect_info.rs @@ -51,6 +51,9 @@ async fn getting_connect_info() { #[cfg(unix)] pub mod unix { + use std::io; + + use hyper_util::rt::TokioIo; use tokio::{ net::{UnixListener, UnixStream}, sync::oneshot, @@ -106,7 +109,10 @@ pub mod unix { let path = unix_socket_path.clone(); let channel = Endpoint::try_from("http://[::]:50051") .unwrap() - .connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone()))) + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + async move { Ok::<_, io::Error>(TokioIo::new(UnixStream::connect(path).await?)) } + })) .await .unwrap(); diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs index b112f8e66..b2380181d 100644 --- a/tests/integration_tests/tests/extensions.rs +++ b/tests/integration_tests/tests/extensions.rs @@ -1,4 +1,4 @@ -use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; +use hyper::{Request as HyperRequest, Response as HyperResponse}; use integration_tests::{ pb::{test_client, test_server, Input, Output}, BoxFuture, @@ -16,6 +16,7 @@ use tonic::{ }; use tower_service::Service; +#[derive(Clone)] struct ExtensionValue(i32); #[tokio::test] @@ -112,9 +113,9 @@ struct InterceptedService { inner: S, } -impl Service> for InterceptedService +impl Service> for InterceptedService where - S: Service, Response = HyperResponse> + S: Service, Response = HyperResponse> + NamedService + Clone + Send @@ -129,7 +130,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, mut req: HyperRequest) -> Self::Future { + fn call(&mut self, mut req: HyperRequest) -> Self::Future { let clone = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, clone); diff --git a/tests/integration_tests/tests/max_message_size.rs b/tests/integration_tests/tests/max_message_size.rs index 9ae524dbc..f03699cdf 100644 --- a/tests/integration_tests/tests/max_message_size.rs +++ b/tests/integration_tests/tests/max_message_size.rs @@ -1,5 +1,6 @@ use std::pin::Pin; +use hyper_util::rt::TokioIo; use integration_tests::{ pb::{test1_client, test1_server, Input1, Output1}, trace_init, @@ -163,7 +164,7 @@ async fn response_stream_limit() { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, @@ -332,7 +333,7 @@ async fn max_message_run(case: &TestCase) -> Result<(), Status> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index f149dc68d..e41287245 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -76,9 +76,9 @@ struct OriginService { inner: S, } -impl Service> for OriginService +impl Service> for OriginService where - T: Service>, + T: Service>, T::Future: Send + 'static, T::Error: Into>, { @@ -90,7 +90,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { assert_eq!(req.uri().host(), Some("docs.rs")); let fut = self.inner.call(req); diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 3fdabcd36..df6bc4b3b 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -1,5 +1,6 @@ use bytes::Bytes; use http::Uri; +use hyper_util::rt::TokioIo; use integration_tests::mock::MockStream; use integration_tests::pb::{ test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, @@ -183,7 +184,7 @@ async fn status_from_server_stream_with_source() { let channel = Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move { - Err::(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) + Err::, _>(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) })); let mut client = test_stream_client::TestStreamClient::new(channel); diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index d6649f65c..157605a95 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -18,14 +18,14 @@ version = "0.11.0" base64 = "0.22" bytes = "1" tokio-stream = "0.1" -http = "0.2" -http-body = "0.4" -hyper = {version = "0.14", default-features = false, features = ["stream"]} +http = "1" +http-body = "1" +http-body-util = "0.1" pin-project = "1" tonic = {version = "0.11", path = "../tonic", default-features = false} tower-service = "0.3" tower-layer = "0.3" -tower-http = { version = "0.4", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors"] } tracing = "0.1" [dev-dependencies] diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index f52087e9e..178e620ae 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -5,7 +5,7 @@ use std::task::{ready, Context, Poll}; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{header, HeaderMap, HeaderName, HeaderValue}; -use http_body::{Body, SizeHint}; +use http_body::{Body, Frame, SizeHint}; use pin_project::pin_project; use tokio_stream::Stream; use tonic::Status; @@ -63,9 +63,9 @@ pub struct GrpcWebCall { #[pin] inner: B, buf: BytesMut, + decoded: BytesMut, direction: Direction, encoding: Encoding, - poll_trailers: bool, client: bool, trailers: Option, } @@ -75,9 +75,9 @@ impl Default for GrpcWebCall { Self { inner: Default::default(), buf: Default::default(), + decoded: Default::default(), direction: Direction::Empty, encoding: Encoding::None, - poll_trailers: Default::default(), client: Default::default(), trailers: Default::default(), } @@ -108,9 +108,12 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(match direction { + Direction::Decode => BUFFER_SIZE, + _ => 0, + }), direction, encoding, - poll_trailers: true, client: true, trailers: None, } @@ -123,9 +126,9 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(0), direction, encoding, - poll_trailers: true, client: false, trailers: None, } @@ -160,24 +163,37 @@ where B: Body, B::Error: Error, { + // Poll body for data, decoding (e.g. via Base64 if necessary) and returning frames + // to the caller. If the caller is a client, it should look for trailers before + // returning these frames. fn poll_decode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { match self.encoding { Encoding::Base64 => loop { if let Some(bytes) = self.as_mut().decode_chunk()? { - return Poll::Ready(Some(Ok(bytes))); + return Poll::Ready(Some(Ok(Frame::data(bytes)))); } let mut this = self.as_mut().project(); - match ready!(this.inner.as_mut().poll_data(cx)) { - Some(Ok(data)) => this.buf.put(data), + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => this.buf.put(frame.into_data().unwrap()), + Some(Ok(frame)) if frame.is_trailers() => { + return Poll::Ready(Some(Err(internal_error( + "malformed base64 request has unencoded trailers", + )))) + } + Some(Ok(_)) => { + return Poll::Ready(Some(Err(internal_error("unexpected frame type")))) + } Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))), None => { return if this.buf.has_remaining() { Poll::Ready(Some(Err(internal_error("malformed base64 request")))) + } else if let Some(trailers) = this.trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } @@ -185,7 +201,7 @@ where } }, - Encoding::None => match ready!(self.project().inner.poll_data(cx)) { + Encoding::None => match ready!(self.project().inner.poll_frame(cx)) { Some(res) => Poll::Ready(Some(res.map_err(internal_error))), None => Poll::Ready(None), }, @@ -195,37 +211,33 @@ where fn poll_encode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { let mut this = self.as_mut().project(); - if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) { - if *this.encoding == Encoding::Base64 { - res = res.map(|b| crate::util::base64::STANDARD.encode(b).into()) - } - - return Poll::Ready(Some(res.map_err(internal_error))); - } - - // this flag is needed because the inner stream never - // returns Poll::Ready(None) when polled for trailers - if *this.poll_trailers { - return match ready!(this.inner.poll_trailers(cx)) { - Ok(Some(map)) => { - let mut frame = make_trailers_frame(map); + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => { + let mut res = frame.into_data().unwrap(); - if *this.encoding == Encoding::Base64 { - frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); - } + if *this.encoding == Encoding::Base64 { + let mut buf = Vec::with_capacity(res.len()); + buf.extend_from_slice(&res); + res = crate::util::base64::STANDARD.encode(buf).into(); + } - *this.poll_trailers = false; - Poll::Ready(Some(Ok(frame.into()))) + Poll::Ready(Some(Ok(Frame::data(res)))) + } + Some(Ok(frame)) if frame.is_trailers() => { + let trailers = frame.into_trailers().expect("must be trailers"); + let mut frame = make_trailers_frame(trailers); + if *this.encoding == Encoding::Base64 { + frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); } - Ok(None) => Poll::Ready(None), - Err(e) => Poll::Ready(Some(Err(internal_error(e)))), - }; + Poll::Ready(Some(Ok(Frame::data(frame.into())))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexepected frame type")))), + Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))), + None => Poll::Ready(None), } - - Poll::Ready(None) } } @@ -237,28 +249,34 @@ where type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { if self.client && self.direction == Direction::Decode { let mut me = self.as_mut(); loop { - let incoming_buf = match ready!(me.as_mut().poll_decode(cx)) { - Some(Ok(incoming_buf)) => incoming_buf, - None => { - // TODO: Consider eofing here? - // Even if the buffer has more data, this will hit the eof branch - // of decode in tonic - return Poll::Ready(None); + match ready!(me.as_mut().poll_decode(cx)) { + Some(Ok(incoming_buf)) if incoming_buf.is_data() => { + me.as_mut() + .project() + .decoded + .put(incoming_buf.into_data().unwrap()); + } + Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => { + let trailers = incoming_buf.into_trailers().unwrap(); + me.as_mut().project().trailers.replace(trailers); + continue; } + Some(Ok(_)) => unreachable!("unexpected frame type"), + None => {} // No more data to decode, time to look for trailers Some(Err(e)) => return Poll::Ready(Some(Err(e))), }; - let buf = &mut me.as_mut().project().buf; - - buf.put(incoming_buf); + // Hold the incoming, decoded data until we have a full message + // or trailers to return. + let buf = me.as_mut().project().decoded; return match find_trailers(&buf[..])? { FindTrailers::Trailer(len) => { @@ -266,20 +284,24 @@ where let msg_buf = buf.copy_to_bytes(len); match decode_trailers_frame(buf.split().freeze()) { Ok(Some(trailers)) => { - self.project().trailers.replace(trailers); + me.as_mut().project().trailers.replace(trailers); } Err(e) => return Poll::Ready(Some(Err(e))), _ => {} } if msg_buf.has_remaining() { - Poll::Ready(Some(Ok(msg_buf))) + Poll::Ready(Some(Ok(Frame::data(msg_buf)))) + } else if let Some(trailers) = me.as_mut().project().trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } } FindTrailers::IncompleteBuf => continue, - FindTrailers::Done(len) => Poll::Ready(Some(Ok(buf.split_to(len).freeze()))), + FindTrailers::Done(len) => { + Poll::Ready(Some(Ok(Frame::data(buf.split_to(len).freeze())))) + } }; } } @@ -291,14 +313,6 @@ where } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>, Self::Error>> { - let trailers = self.project().trailers.take(); - Poll::Ready(Ok(trailers)) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -313,10 +327,10 @@ where B: Body, B::Error: Error, { - type Item = Result; + type Item = Result, Status>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Body::poll_data(self, cx) + self.poll_frame(cx) } } diff --git a/tonic-web/src/layer.rs b/tonic-web/src/layer.rs index 77b03c77e..7834f1990 100644 --- a/tonic-web/src/layer.rs +++ b/tonic-web/src/layer.rs @@ -24,7 +24,7 @@ impl Default for GrpcWebLayer { impl Layer for GrpcWebLayer where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 16e57e19d..50ed8c0a8 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -127,7 +127,7 @@ type BoxError = Box; /// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice. pub fn enable(service: S) -> CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -159,9 +159,9 @@ where #[derive(Debug, Clone)] pub struct CorsGrpcWeb(tower_http::cors::Cors>); -impl Service> for CorsGrpcWeb +impl Service> for CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -169,7 +169,7 @@ where type Response = S::Response; type Error = S::Error; type Future = - > as Service>>::Future; + > as Service>>::Future; fn poll_ready( &mut self, @@ -178,7 +178,7 @@ where self.0.poll_ready(cx) } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { self.0.call(req) } } diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index af4c5276f..da65ba832 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; -use hyper::Body; +use http_body_util::BodyExt; use pin_project::pin_project; use tonic::{ body::{empty_body, BoxBody}, @@ -50,7 +50,7 @@ impl GrpcWebService { impl GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, { fn response(&self, status: StatusCode) -> ResponseFuture { ResponseFuture { @@ -66,9 +66,9 @@ where } } -impl Service> for GrpcWebService +impl Service> for GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { @@ -80,7 +80,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { match RequestKind::new(req.headers(), req.method(), req.version()) { // A valid grpc-web request, regardless of HTTP version. // @@ -202,7 +202,7 @@ impl<'a> RequestKind<'a> { // Mutating request headers to conform to a gRPC request is not really // necessary for us at this point. We could remove most of these except // maybe for inserting `header::TE`, which tonic should check? -fn coerce_request(mut req: Request, encoding: Encoding) -> Request { +fn coerce_request(mut req: Request, encoding: Encoding) -> Request { req.headers_mut().remove(header::CONTENT_LENGTH); req.headers_mut() @@ -216,8 +216,7 @@ fn coerce_request(mut req: Request, encoding: Encoding) -> Request { HeaderValue::from_static("identity,deflate,gzip"), ); - req.map(|b| GrpcWebCall::request(b, encoding)) - .map(Body::wrap_stream) + req.map(|b| GrpcWebCall::request(b, encoding).boxed_unsync()) } fn coerce_response(res: Response, encoding: Encoding) -> Response { @@ -246,7 +245,7 @@ mod tests { #[derive(Debug, Clone)] struct Svc; - impl tower_service::Service> for Svc { + impl tower_service::Service> for Svc { type Response = Response; type Error = String; type Future = BoxFuture; @@ -255,7 +254,7 @@ mod tests { Poll::Ready(Ok(())) } - fn call(&mut self, _: Request) -> Self::Future { + fn call(&mut self, _: Request) -> Self::Future { Box::pin(async { Ok(Response::new(empty_body())) }) } } @@ -266,15 +265,14 @@ mod tests { mod grpc_web { use super::*; - use http::HeaderValue; use tower_layer::Layer; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::POST) .header(CONTENT_TYPE, GRPC_WEB) .header(ORIGIN, "http://example.com") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -350,13 +348,13 @@ mod tests { mod options { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::OPTIONS) .header(ORIGIN, "http://example.com") .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web") .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -371,13 +369,12 @@ mod tests { mod grpc { use super::*; - use http::HeaderValue; - fn request() -> Request { + fn request() -> Request { Request::builder() .version(Version::HTTP_2) .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -397,7 +394,7 @@ mod tests { let req = Request::builder() .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap(); let res = svc.call(req).await.unwrap(); @@ -425,10 +422,10 @@ mod tests { mod other { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .header(CONTENT_TYPE, "application/text") - .body(Body::empty()) + .body(empty_body()) .unwrap() } diff --git a/tonic-web/tests/integration/Cargo.toml b/tonic-web/tests/integration/Cargo.toml index 5c6d5727e..38fd9ff32 100644 --- a/tonic-web/tests/integration/Cargo.toml +++ b/tonic-web/tests/integration/Cargo.toml @@ -9,7 +9,10 @@ license = "MIT" [dependencies] base64 = "0.22" bytes = "1.0" -hyper = "0.14" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" prost = "0.12" tokio = { version = "1", features = ["macros", "rt", "net"] } tokio-stream = { version = "0.1", features = ["net"] } diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 3343d754c..2c57f2680 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -2,21 +2,27 @@ use std::net::SocketAddr; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http_body_util::{BodyExt as _, Full}; +use hyper::body::Incoming; use hyper::http::{header, StatusCode}; -use hyper::{Body, Client, Method, Request, Uri}; +use hyper::{Method, Request, Uri}; +use hyper_util::client::legacy::Client; +use hyper_util::rt::TokioExecutor; use prost::Message; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; +use tonic::body::BoxBody; use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; use integration::Svc; +use tonic::Status; use tonic_web::GrpcWebLayer; #[tokio::test] async fn binary_request() { let server_url = spawn().await; - let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web", "grpc-web"); let res = client.request(req).await.unwrap(); @@ -39,7 +45,7 @@ async fn binary_request() { #[tokio::test] async fn text_request() { let server_url = spawn().await; - let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); let res = client.request(req).await.unwrap(); @@ -102,7 +108,7 @@ fn encode_body() -> Bytes { buf.split_to(len + 5).freeze() } -fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { +fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { use header::{ACCEPT, CONTENT_TYPE, ORIGIN}; let request_uri = format!("{}/{}/{}", base_uri, "test.Test", "UnaryCall") @@ -123,12 +129,14 @@ fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request< .header(ORIGIN, "http://example.com") .header(ACCEPT, format!("application/{}", accept)) .uri(request_uri) - .body(Body::from(bytes)) + .body(BoxBody::new( + Full::new(bytes).map_err(|err| Status::internal(err.to_string())), + )) .unwrap() } -async fn decode_body(body: Body, content_type: &str) -> (Output, Bytes) { - let mut body = hyper::body::to_bytes(body).await.unwrap(); +async fn decode_body(body: Incoming, content_type: &str) -> (Output, Bytes) { + let mut body = body.collect().await.unwrap().to_bytes(); if content_type == "application/grpc-web-text+proto" { body = integration::util::base64::STANDARD diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 0d2be669f..b795cc5f9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,10 +37,10 @@ transport = [ "dep:axum", "channel", "dep:h2", - "dep:hyper", - "dep:tokio", "tokio?/net", "tokio?/time", + "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", + "dep:socket2", + "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", "dep:tower", - "dep:hyper-timeout", ] channel = [] @@ -51,10 +51,11 @@ channel = [] [dependencies] base64 = "0.22" bytes = "1.0" -http = "0.2" +http = "1" tracing = "0.1" -http-body = "0.4.4" +http-body = "1" +http-body-util = "0.1" percent-encoding = "2.1" pin-project = "1.0.11" tower-layer = "0.3" @@ -68,13 +69,15 @@ async-trait = {version = "0.1.13", optional = true} # transport async-stream = {version = "0.3", optional = true} -h2 = {version = "0.3.24", optional = true} -hyper = {version = "0.14.26", features = ["full"], optional = true} -hyper-timeout = {version = "0.4", optional = true} -tokio = {version = "1.0.1", optional = true} -tokio-stream = "0.1" +h2 = {version = "0.4", optional = true} +hyper = {version = "1", features = ["full"], optional = true} +hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true } +hyper-timeout = {version = "0.5", optional = true} +socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } +tokio = {version = "1", default-features = false, optional = true} +tokio-stream = { version = "0.1", features = ["net"] } tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} -axum = {version = "0.6.9", default-features = false, optional = true} +axum = {version = "0.7", default-features = false, optional = true} # rustls rustls-pemfile = { version = "2.0", optional = true } diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..22ab6d9d4 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -1,6 +1,6 @@ use bencher::{benchmark_group, benchmark_main, Bencher}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use http_body::Body; +use http_body::{Body, Frame, SizeHint}; use std::{ fmt::{Error, Formatter}, pin::Pin, @@ -58,23 +58,24 @@ impl Body for MockBody { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { if self.data.has_remaining() { let split = std::cmp::min(self.chunk_size, self.data.remaining()); - Poll::Ready(Some(Ok(self.data.split_to(split)))) + Poll::Ready(Some(Ok(Frame::data(self.data.split_to(split))))) } else { Poll::Ready(None) } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + fn is_end_stream(&self) -> bool { + !self.data.is_empty() + } + + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.data.len() as u64) } } diff --git a/tonic/src/body.rs b/tonic/src/body.rs index ef95eec47..428c0dade 100644 --- a/tonic/src/body.rs +++ b/tonic/src/body.rs @@ -1,9 +1,9 @@ //! HTTP specific body utilities. -use http_body::Body; +use http_body_util::BodyExt; /// A type erased HTTP body used for tonic services. -pub type BoxBody = http_body::combinators::UnsyncBoxBody; +pub type BoxBody = http_body_util::combinators::UnsyncBoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. pub(crate) fn boxed(body: B) -> BoxBody @@ -16,7 +16,7 @@ where /// Create an empty `BoxBody` pub fn empty_body() -> BoxBody { - http_body::Empty::new() + http_body_util::Empty::new() .map_err(|err| match err {}) .boxed_unsync() } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 081f6193d..dea83d931 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -2,8 +2,9 @@ use super::compression::{decompress, CompressionEncoding, CompressionSettings}; use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use http_body::Body; +use http_body_util::BodyExt; use std::{ fmt, future, pin::Pin, @@ -27,7 +28,7 @@ struct StreamingInner { state: State, direction: Direction, buf: BytesMut, - trailers: Option, + trailers: Option, decompress_buf: BytesMut, encoding: Option, max_message_size: Option, @@ -121,7 +122,7 @@ impl Streaming { decoder: Box::new(decoder), inner: StreamingInner { body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_frame(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))) .map_err(|err| Status::map_error(err.into())) .boxed_unsync(), state: State::ReadHeader, @@ -239,8 +240,8 @@ impl StreamingInner { } // Returns Some(()) if data was found or None if the loop in `poll_next` should break - fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { + let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { Some(Ok(d)) => Some(d), Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { @@ -254,9 +255,18 @@ impl StreamingInner { None => None, }; - Poll::Ready(if let Some(data) = chunk { - self.buf.put(data); - Ok(Some(())) + Poll::Ready(if let Some(frame) = chunk { + match frame { + frame if frame.is_data() => { + self.buf.put(frame.into_data().unwrap()); + Ok(Some(())) + } + frame if frame.is_trailers() => { + self.trailers = Some(frame.into_trailers().unwrap()); + Ok(None) + } + frame => panic!("unexpected frame: {:?}", frame), + } } else { // FIXME: improve buf usage. if self.buf.has_remaining() { @@ -271,27 +281,18 @@ impl StreamingInner { }) } - fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { + fn response(&mut self) -> Result<(), Status> { if let Direction::Response(status) = self.direction { - match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { - Ok(trailer) => { - if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - if let Some(e) = e { - return Poll::Ready(Err(e)); - } else { - return Poll::Ready(Ok(())); - } - } else { - self.trailers = trailer.map(MetadataMap::from_headers); - } - } - Err(status) => { - debug!("decoder inner trailers error: {:?}", status); - return Poll::Ready(Err(status)); + if let Err(e) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) { + if let Some(e) = e { + // If the trailers contain a grpc-status, then we should return that as the error + // and otherwise stop the stream (by taking the error state) + self.trailers.take(); + return Err(e); } } } - Poll::Ready(Ok(())) + Ok(()) } } @@ -351,7 +352,7 @@ impl Streaming { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } // To fetch the trailers we must clear the body and drop it. @@ -360,16 +361,11 @@ impl Streaming { // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } - // Trailers were not caught during poll_next and thus lets poll for - // them manually. - let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) - .await - .map_err(|e| Status::from_error(Box::new(e))); - - map.map(|x| x.map(MetadataMap::from_headers)) + // We've polled through all the frames, and still no trailers, return None + Ok(None) } fn decode_chunk(&mut self) -> Result, Status> { @@ -395,20 +391,17 @@ impl Stream for Streaming { return Poll::Ready(None); } - // FIXME: implement the ability to poll trailers when we _know_ that - // the consumer of this stream will only poll for the first message. - // This means we skip the poll_trailers step. if let Some(item) = self.decode_chunk()? { return Poll::Ready(Some(Ok(item))); } - match ready!(self.inner.poll_data(cx))? { + match ready!(self.inner.poll_frame(cx))? { Some(()) => (), None => break, } } - Poll::Ready(match ready!(self.inner.poll_response(cx)) { + Poll::Ready(match self.inner.response() { Ok(()) => None, Err(err) => Some(Err(err)), }) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 07aed1dda..3f7c628ee 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -5,7 +5,7 @@ use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, H use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project::pin_project; use std::{ pin::Pin, @@ -298,22 +298,21 @@ where } impl EncodeState { - fn trailers(&mut self) -> Result, Status> { + fn trailers(&mut self) -> Option> { match self.role { - Role::Client => Ok(None), + Role::Client => None, Role::Server => { if self.is_end_stream { - return Ok(None); + return None; } + self.is_end_stream = true; let status = if let Some(status) = self.error.take() { - self.is_end_stream = true; status } else { Status::new(Code::Ok, "") }; - - Ok(Some(status.to_header_map()?)) + Some(status.to_header_map()) } } } @@ -330,38 +329,25 @@ where self.state.is_end_stream } - fn size_hint(&self) -> http_body::SizeHint { - let sh = self.inner.size_hint(); - let mut size_hint = http_body::SizeHint::new(); - size_hint.set_lower(sh.0 as u64); - if let Some(upper) = sh.1 { - size_hint.set_upper(upper as u64); - } - size_hint - } - - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let self_proj = self.project(); match ready!(self_proj.inner.poll_next(cx)) { - Some(Ok(d)) => Some(Ok(d)).into(), + Some(Ok(d)) => Some(Ok(Frame::data(d))).into(), Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { - self_proj.state.error = Some(status); - None.into() + self_proj.state.is_end_stream = true; + Some(Ok(Frame::trailers(status.to_header_map()?))).into() } }, - None => None.into(), + None => self_proj + .state + .trailers() + .map(|t| t.map(Frame::trailers)) + .into(), } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Status>> { - Poll::Ready(self.project().state.trailers()) - } } diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 217934e9e..3e6789daf 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -156,6 +156,7 @@ mod tests { use crate::{Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http_body::Body; + use http_body_util::BodyExt as _; use std::pin::pin; const LEN: usize = 10000; @@ -238,7 +239,7 @@ mod tests { None, )); - while let Some(r) = body.data().await { + while let Some(r) = body.frame().await { r.unwrap(); } } @@ -260,12 +261,15 @@ mod tests { Some(MAX_MESSAGE_SIZE), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "11" @@ -292,12 +296,15 @@ mod tests { Some(usize::MAX), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "8" @@ -343,7 +350,7 @@ mod tests { mod body { use crate::Status; use bytes::Bytes; - use http_body::Body; + use http_body::{Body, Frame}; use std::{ pin::Pin, task::{Context, Poll}, @@ -374,10 +381,10 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { // every other call to poll_data returns data let should_send = self.count % 2 == 0; let data_len = self.data.len(); @@ -395,18 +402,11 @@ mod tests { }; // make some fake progress self.count += 1; - result + result.map(|opt| opt.map(|res| res.map(|data| Frame::data(data)))) } else { Poll::Ready(None) } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } } } diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index 37d84b87b..32b9ad021 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -24,7 +24,7 @@ impl Extensions { /// If a extension of this type already existed, it will /// be returned. #[inline] - pub fn insert(&mut self, val: T) -> Option { + pub fn insert(&mut self, val: T) -> Option { self.inner.insert(val) } diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 76bf4e9eb..e0829424c 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -311,6 +311,7 @@ impl Request { /// ```no_run /// use tonic::{Request, service::interceptor}; /// + /// #[derive(Clone)] // Extensions must be Clone /// struct MyExtension { /// some_piece_of_data: String, /// } @@ -438,7 +439,6 @@ pub(crate) enum SanitizeHeaders { #[cfg(test)] mod tests { use super::*; - use crate::metadata::MetadataValue; use http::Uri; #[test] diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index cadff466f..ebe78093d 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -232,11 +232,8 @@ where mod tests { #[allow(unused_imports)] use super::*; - use http::header::HeaderMap; - use std::{ - pin::Pin, - task::{Context, Poll}, - }; + use http_body::Frame; + use http_body_util::Empty; use tower::ServiceExt; #[derive(Debug, Default)] @@ -246,19 +243,12 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { Poll::Ready(None) } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } #[tokio::test] @@ -318,17 +308,17 @@ mod tests { #[tokio::test] async fn doesnt_change_http_method() { - let svc = tower::service_fn(|request: http::Request| async move { + let svc = tower::service_fn(|request: http::Request>| async move { assert_eq!(request.method(), http::Method::OPTIONS); - Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + Ok::<_, hyper::Error>(hyper::Response::new(Empty::new())) }); let svc = InterceptedService::new(svc, Ok); let request = http::Request::builder() .method(http::Method::OPTIONS) - .body(hyper::Body::empty()) + .body(Empty::new()) .unwrap(); svc.oneshot(request).await.unwrap(); diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 4ce3abde6..77c85b47d 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -412,13 +412,7 @@ impl Status { // > status. Note that the frequency of PINGs is highly dependent on the network // > environment, implementations are free to adjust PING frequency based on network and // > application requirements, which is why it's mapped to unavailable here. - // - // Likewise, if we are unable to connect to the server, map this to UNAVAILABLE. This is - // consistent with the behavior of a C++ gRPC client when the server is not running, and - // matches the spec of: - // > The service is currently unavailable. This is most likely a transient condition that - // > can be corrected if retried with a backoff. - if err.is_timeout() || err.is_connect() { + if err.is_timeout() { return Some(Status::unavailable(err.to_string())); } @@ -623,6 +617,11 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { return Some(Status::cancelled(timeout.to_string())); } + #[cfg(feature = "transport")] + if let Some(connect) = err.downcast_ref::() { + return Some(Status::unavailable(connect.to_string())); + } + #[cfg(feature = "transport")] if let Some(hyper) = err .downcast_ref::() diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 995e2a15b..6014960a8 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -7,8 +7,10 @@ use crate::transport::service::TlsConnector; use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; +use hyper::rt; +use hyper_util::client::legacy::connect::HttpConnector; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; -use tower::make::MakeConnection; +use tower_service::Service; /// Channel builder. /// @@ -332,7 +334,7 @@ impl Endpoint { /// Create a channel from this config. pub async fn connect(&self) -> Result { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -348,7 +350,7 @@ impl Endpoint { /// The channel returned by this method does not attempt to connect to the endpoint until first /// use. pub fn connect_lazy(&self) -> Channel { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -368,8 +370,8 @@ impl Endpoint { /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied. pub async fn connect_with_connector(&self, connector: C) -> Result where - C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C: Service + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -393,8 +395,8 @@ impl Endpoint { /// uses a Unix socket transport. pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where - C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C: Service + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..0983725f8 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -17,7 +17,7 @@ use http::{ uri::{InvalidUri, Uri}, Request, Response, }; -use hyper::client::connect::Connection as HyperConnection; +use hyper_util::client::legacy::connect::Connection as HyperConnection; use std::{ fmt, future::Future, @@ -25,11 +25,9 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{channel, Sender}, -}; +use tokio::sync::mpsc::{channel, Sender}; +use hyper::rt; use tower::balance::p2c::Balance; use tower::{ buffer::{self, Buffer}, @@ -38,13 +36,13 @@ use tower::{ Service, }; -type Svc = Either, Response, crate::Error>>; +type Svc = Either, Response, crate::Error>>; const DEFAULT_BUFFER_SIZE: usize = 1024; /// A default batteries included `transport` channel. /// -/// This provides a fully featured http2 gRPC client based on [`hyper::Client`] +/// This provides a fully featured http2 gRPC client based on `hyper` /// and `tower` services. /// /// # Multiplexing requests @@ -152,7 +150,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -169,7 +167,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -201,7 +199,7 @@ impl Channel { } impl Service> for Channel { - type Response = http::Response; + type Response = http::Response; type Error = super::Error; type Future = ResponseFuture; @@ -217,7 +215,7 @@ impl Service> for Channel { } impl Future for ResponseFuture { - type Output = Result, super::Error>; + type Output = Result, super::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let val = ready!(Pin::new(&mut self.inner).poll(cx)).map_err(super::Error::from_source)?; diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index a0435c797..357ab2bf4 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -1,7 +1,7 @@ //! Batteries included server and client. //! //! This module provides a set of batteries included, fully featured and -//! fast set of HTTP/2 server and client's. These components each provide a or +//! fast set of HTTP/2 server and client's. These components each provide a //! `rustls` tls backend when the respective feature flag is enabled, and //! provides builders to configure transport behavior. //! @@ -22,6 +22,7 @@ //! # use tonic::transport::{Channel, Certificate, ClientTlsConfig}; //! # use std::time::Duration; //! # use tonic::body::BoxBody; +//! # use tonic::body::empty_body; //! # use tonic::client::GrpcService;; //! # use http::Request; //! # #[cfg(feature = "rustls")] @@ -38,7 +39,7 @@ //! .connect() //! .await?; //! -//! channel.call(Request::new(BoxBody::empty())).await?; +//! channel.call(Request::new(empty_body())).await?; //! # Ok(()) //! # } //! ``` @@ -46,21 +47,23 @@ //! ## Server //! //! ```no_run +//! # use std::convert::Infallible; //! # #[cfg(feature = "rustls")] //! # use tonic::transport::{Server, Identity, ServerTlsConfig}; +//! # use tonic::body::BoxBody; //! # use tower::Service; //! # #[cfg(feature = "rustls")] //! # async fn do_thing() -> Result<(), Box> { //! # #[derive(Clone)] //! # pub struct Svc; -//! # impl Service> for Svc { -//! # type Response = hyper::Response; -//! # type Error = tonic::Status; +//! # impl Service> for Svc { +//! # type Response = hyper::Response; +//! # type Error = Infallible; //! # type Future = std::future::Ready>; //! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { //! # Ok(()).into() //! # } -//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { +//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { //! # unimplemented!() //! # } //! # } @@ -104,11 +107,14 @@ pub use self::error::Error; pub use self::server::Server; #[doc(inline)] pub use self::service::grpc_timeout::TimeoutExpired; +pub(crate) use self::service::ConnectError; + #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; -pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; -pub use hyper::{Body, Uri}; +pub use axum::{body::Body as AxumBoxBody, Router as AxumRouter}; +pub use hyper::body::Body; +pub use hyper::Uri; pub(crate) use self::service::executor::Executor; diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 907cf4965..37bcc561b 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,4 +1,3 @@ -use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; @@ -86,17 +85,6 @@ impl TcpConnectInfo { } } -impl Connected for AddrStream { - type ConnectInfo = TcpConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - TcpConnectInfo { - local_addr: Some(self.local_addr()), - remote_addr: Some(self.remote_addr()), - } - } -} - impl Connected for TcpStream { type ConnectInfo = TcpConnectInfo; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index bc1bb7650..7f5f76c25 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,20 +1,18 @@ use super::{Connected, Server}; use crate::transport::service::ServerIo; -use hyper::server::{ - accept::Accept, - conn::{AddrIncoming, AddrStream}, -}; use std::{ - net::SocketAddr, + net::{SocketAddr, TcpListener as StdTcpListener}, pin::{pin, Pin}, - task::{Context, Poll}, + task::{ready, Context, Poll}, time::Duration, }; use tokio::{ io::{AsyncRead, AsyncWrite}, - net::TcpListener, + net::{TcpListener, TcpStream}, }; +use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::{Stream, StreamExt}; +use tracing::warn; #[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( @@ -127,7 +125,9 @@ enum SelectOutput { /// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address. #[derive(Debug)] pub struct TcpIncoming { - inner: AddrIncoming, + inner: TcpListenerStream, + nodelay: bool, + keepalive: Option, } impl TcpIncoming { @@ -139,13 +139,13 @@ impl TcpIncoming { /// ```no_run /// # use tower_service::Service; /// # use http::{request::Request, response::Response}; - /// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}}; + /// # use tonic::{body::BoxBody, server::NamedService, transport::{Server, server::TcpIncoming}}; /// # use core::convert::Infallible; /// # use std::error::Error; /// # fn main() { } // Cannot have type parameters, hence instead define: /// # fn run(some_service: S) -> Result<(), Box> /// # where - /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, + /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, /// # S::Future: Send + 'static, /// # { /// // Find a free port @@ -167,10 +167,15 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::bind(&addr)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + let std_listener = StdTcpListener::bind(addr)?; + std_listener.set_nonblocking(true)?; + + let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?); + Ok(Self { + inner, + nodelay, + keepalive, + }) } /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. @@ -179,18 +184,43 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::from_listener(listener)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + Ok(Self { + inner: TcpListenerStream::new(listener), + nodelay, + keepalive, + }) } } impl Stream for TcpIncoming { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_accept(cx) + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(stream)) => { + set_accepted_socket_options(&stream, self.nodelay, self.keepalive); + Some(Ok(stream)).into() + } + other => Poll::Ready(other), + } + } +} + +// Consistent with hyper-0.14, this function does not return an error. +fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option) { + if nodelay { + if let Err(e) = stream.set_nodelay(true) { + warn!("error trying to set TCP nodelay: {}", e); + } + } + + if let Some(timeout) = keepalive { + let sock_ref = socket2::SockRef::from(&stream); + let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); + + if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { + warn!("error trying to set TCP keepalive: {}", e); + } } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7f2ffde2b..fb63058ad 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,10 +9,18 @@ mod tls; #[cfg(unix)] mod unix; +use tokio_stream::StreamExt as _; +use tower::util::BoxCloneService; +use tower::util::Oneshot; +use tower::ServiceExt; +use tracing::debug; +use tracing::trace; + pub use super::service::Routes; pub use super::service::RoutesBuilder; pub use conn::{Connected, TcpConnectInfo}; +use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; @@ -35,20 +43,22 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; +use crate::body::boxed; use crate::body::BoxBody; use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; -use http_body::Body as _; -use hyper::{server::accept, Body}; +use http_body_util::BodyExt; +use hyper::body::Incoming; use pin_project::pin_project; +use std::future::poll_fn; use std::{ convert::Infallible, fmt, future::{self, Future}, marker::PhantomData, net::SocketAddr, - pin::Pin, + pin::{pin, Pin}, sync::Arc, task::{ready, Context, Poll}, time::Duration, @@ -63,16 +73,17 @@ use tower::{ Service, ServiceBuilder, }; -type BoxHttpBody = http_body::combinators::UnsyncBoxBody; -type BoxService = tower::util::BoxService, Response, crate::Error>; +type BoxHttpBody = crate::body::BoxBody; +type BoxError = crate::Error; +type BoxService = tower::util::BoxCloneService, Response, crate::Error>; type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; /// A default batteries included `transport` server. /// -/// This is a wrapper around [`hyper::Server`] and provides an easy builder -/// pattern style builder [`Server`]. This builder exposes easy configuration parameters +/// This provides an easy builder pattern style builder [`Server`] on top of +/// `hyper` connections. This builder exposes easy configuration parameters /// for providing a fully featured http2 based gRPC server. This should provide /// a very good out of the box http2 server for use with tonic but is also a /// reference implementation that should be a good starting point for anyone @@ -122,7 +133,7 @@ impl Default for Server { } } -/// A stack based `Service` router. +/// A stack based [`Service`] router. #[derive(Debug)] pub struct Router { server: Server, @@ -359,7 +370,7 @@ impl Server { /// route around different services. pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -380,7 +391,7 @@ impl Server { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -494,9 +505,11 @@ impl Server { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send + 'static, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IO::ConnectInfo: Clone + Send + Sync + 'static, @@ -523,10 +536,8 @@ impl Server { let svc = self.service_builder.service(svc); - let tcp = incoming::tcp_incoming(incoming, self); - let incoming = accept::from_stream::<_, _, crate::Error>(tcp); - - let svc = MakeSvc { + let incoming = incoming::tcp_incoming(incoming, self); + let mut svc = MakeSvc { inner: svc, concurrency_limit, timeout, @@ -534,31 +545,204 @@ impl Server { _io: PhantomData, }; - let server = hyper::Server::builder(incoming) - .http2_only(http2_only) - .http2_initial_connection_window_size(init_connection_window_size) - .http2_initial_stream_window_size(init_stream_window_size) - .http2_max_concurrent_streams(max_concurrent_streams) - .http2_keep_alive_interval(http2_keepalive_interval) - .http2_keep_alive_timeout(http2_keepalive_timeout) - .http2_adaptive_window(http2_adaptive_window.unwrap_or_default()) - .http2_max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) - .http2_max_frame_size(max_frame_size); - - if let Some(signal) = signal { - server - .serve(svc) - .with_graceful_shutdown(signal) - .await - .map_err(super::Error::from_source)? - } else { - server.serve(svc).await.map_err(super::Error::from_source)?; + let server = { + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + + if http2_only { + builder = builder.http2_only(); + } + + builder + .http2() + .initial_connection_window_size(init_connection_window_size) + .initial_stream_window_size(init_stream_window_size) + .max_concurrent_streams(max_concurrent_streams) + .keep_alive_interval(http2_keepalive_interval) + .keep_alive_timeout(http2_keepalive_timeout) + .adaptive_window(http2_adaptive_window.unwrap_or_default()) + .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) + .max_frame_size(max_frame_size); + + builder + }; + + let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + let graceful = signal.is_some(); + let mut sig = pin!(Fuse { inner: signal }); + let mut incoming = pin!(incoming); + + loop { + tokio::select! { + _ = &mut sig => { + trace!("signal received, shutting down"); + break; + }, + io = incoming.next() => { + let io = match io { + Some(Ok(io)) => io, + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + continue; + }, + None => { + break + }, + }; + + trace!("connection accepted"); + + poll_fn(|cx| svc.poll_ready(cx)) + .await + .map_err(super::Error::from_source)?; + + let req_svc = svc + .call(&io) + .await + .map_err(super::Error::from_source)?; + let hyper_svc = TowerToHyperService::new(req_svc); + + serve_connection(io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone())); + } + } + } + + if graceful { + let _ = signal_tx.send(()); + drop(signal_rx); + trace!( + "waiting for {} connections to close", + signal_tx.receiver_count() + ); + + // Wait for all connections to close + signal_tx.closed().await; } Ok(()) } } +// This is moved to its own function as a way to get around +// https://github.com/rust-lang/rust/issues/102211 +fn serve_connection( + io: ServerIo, + hyper_svc: TowerToHyperService, + builder: ConnectionBuilder, + mut watcher: Option>, +) where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, +{ + tokio::spawn(async move { + { + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + let mut conn = pin!(builder.serve_connection(TokioIo::new(io), hyper_svc)); + + loop { + tokio::select! { + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + } + } + } + + drop(watcher); + trace!("connection closed"); + }); +} + +type ConnectionBuilder = hyper_util::server::conn::auto::Builder; + +/// An adaptor which converts a [`tower::Service`] to a [`hyper::service::Service`]. +/// +/// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, +/// and does not support the `poll_ready` method that is used by tower services. +#[derive(Debug, Copy, Clone)] +pub struct TowerToHyperService { + service: S, +} + +impl TowerToHyperService { + /// Create a new `TowerToHyperService` from a tower service. + pub fn new(service: S) -> Self { + Self { service } + } + + /// Extract the inner tower service. + pub fn into_inner(self) -> S { + self.service + } + + /// Get a reference to the inner tower service. + pub fn as_inner(&self) -> &S { + &self.service + } + + /// Get a mutable reference to the inner tower service. + pub fn as_inner_mut(&mut self) -> &mut S { + &mut self.service + } +} + +impl hyper::service::Service> for TowerToHyperService +where + S: tower_service::Service> + Clone, + S::Error: Into + 'static, +{ + type Response = S::Response; + type Error = super::Error; + type Future = TowerToHyperServiceFuture>; + + fn call(&self, req: Request) -> Self::Future { + let req = req.map(crate::body::boxed); + TowerToHyperServiceFuture { + future: self.service.clone().oneshot(req), + } + } +} + +/// Future returned by [`TowerToHyperService`]. +#[derive(Debug)] +#[pin_project] +pub struct TowerToHyperServiceFuture +where + S: tower_service::Service, +{ + #[pin] + future: Oneshot, +} + +impl Future for TowerToHyperServiceFuture +where + S: tower_service::Service, + S::Error: Into + 'static, +{ + type Output = Result; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .future + .poll(cx) + .map_err(super::Error::from_source) + } +} + impl Router { pub(crate) fn new(server: Server, routes: Routes) -> Self { Self { server, routes } @@ -569,7 +753,7 @@ impl Router { /// Add a new service to this router. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -588,7 +772,7 @@ impl Router { #[allow(clippy::type_complexity)] pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -613,10 +797,12 @@ impl Router { /// [tokio]: https://docs.rs/tokio pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> where - L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L: Layer + Clone, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -644,9 +830,11 @@ impl Router { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -673,9 +861,11 @@ impl Router { IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -708,9 +898,11 @@ impl Router { IE: Into, F: Future, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -723,9 +915,11 @@ impl Router { pub fn into_service(self) -> L::Service where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -739,14 +933,15 @@ impl fmt::Debug for Server { } } +#[derive(Clone)] struct Svc { inner: S, trace_interceptor: Option, } -impl Service> for Svc +impl Service> for Svc where - S: Service, Response = Response>, + S: Service, Response = Response>, S::Error: Into, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, @@ -759,7 +954,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -802,7 +997,7 @@ where let _guard = this.span.enter(); let response: Response = ready!(this.inner.poll(cx)).map_err(Into::into)?; - let response = response.map(|body| body.map_err(Into::into).boxed_unsync()); + let response = response.map(|body| boxed(body.map_err(Into::into))); Poll::Ready(Ok(response)) } } @@ -813,6 +1008,7 @@ impl fmt::Debug for Svc { } } +#[derive(Clone)] struct MakeSvc { concurrency_limit: Option, timeout: Option, @@ -824,7 +1020,7 @@ struct MakeSvc { impl Service<&ServerIo> for MakeSvc where IO: Connected, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, ResBody: http_body::Body + Send + 'static, @@ -853,8 +1049,8 @@ where .service(svc); let svc = ServiceBuilder::new() - .layer(BoxService::layer()) - .map_request(move |mut request: Request| { + .layer(BoxCloneService::layer()) + .map_request(move |mut request: Request| { match &conn_info { tower::util::Either::A(inner) => { request.extensions_mut().insert(inner.clone()); @@ -885,3 +1081,29 @@ where future::ready(Ok(svc)) } } + +// From `futures-util` crate, borrowed since this is the only dependency tonic requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +struct Fuse { + #[pin] + inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs index fdb14a66a..60b0d9a7b 100644 --- a/tonic/src/transport/server/recover_error.rs +++ b/tonic/src/transport/server/recover_error.rs @@ -1,5 +1,6 @@ use crate::Status; use http::Response; +use http_body::Frame; use pin_project::pin_project; use std::{ future::Future, @@ -98,26 +99,16 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.as_pin_mut() { - Some(b) => b.poll_data(cx), + Some(b) => b.poll_frame(cx), None => Poll::Ready(None), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.as_pin_mut() { - Some(b) => b.poll_trailers(cx), - None => Poll::Ready(Ok(None)), - } - } - fn is_end_stream(&self) -> bool { match &self.inner { Some(b) => b.is_end_stream(), diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 46a88dda5..a31c9868b 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,17 +1,16 @@ +use super::SharedExec; use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{ - body::BoxBody, + body::{boxed, BoxBody}, transport::{BoxFuture, Endpoint}, }; use http::Uri; -use hyper::client::conn::Builder; -use hyper::client::connect::Connection as HyperConnection; -use hyper::client::service::Connect as HyperConnect; +use hyper::rt; +use hyper::{client::conn::http2::Builder, rt::Executor}; use std::{ fmt, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tower::load::Load; use tower::{ layer::Layer, @@ -21,8 +20,8 @@ use tower::{ }; use tower_service::Service; -pub(crate) type Request = http::Request; -pub(crate) type Response = http::Response; +pub(crate) type Response = http::Response; +pub(crate) type Request = http::Request; pub(crate) struct Connection { inner: BoxService, @@ -34,26 +33,24 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - let mut settings = Builder::new() - .http2_initial_stream_window_size(endpoint.init_stream_window_size) - .http2_initial_connection_window_size(endpoint.init_connection_window_size) - .http2_only(true) - .http2_keep_alive_interval(endpoint.http2_keep_alive_interval) - .executor(endpoint.executor.clone()) + let mut settings: Builder = Builder::new(endpoint.executor.clone()) + .initial_stream_window_size(endpoint.init_stream_window_size) + .initial_connection_window_size(endpoint.init_connection_window_size) + .keep_alive_interval(endpoint.http2_keep_alive_interval) .clone(); if let Some(val) = endpoint.http2_keep_alive_timeout { - settings.http2_keep_alive_timeout(val); + settings.keep_alive_timeout(val); } if let Some(val) = endpoint.http2_keep_alive_while_idle { - settings.http2_keep_alive_while_idle(val); + settings.keep_alive_while_idle(val); } if let Some(val) = endpoint.http2_adaptive_window { - settings.http2_adaptive_window(val); + settings.adaptive_window(val); } let stack = ServiceBuilder::new() @@ -68,13 +65,13 @@ impl Connection { .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); - let connector = HyperConnect::new(connector, settings); - let conn = Reconnect::new(connector, endpoint.uri.clone(), is_lazy); + let make_service = + MakeSendRequestService::new(connector, endpoint.executor.clone(), settings); - let inner = stack.layer(conn); + let conn = Reconnect::new(make_service, endpoint.uri.clone(), is_lazy); Self { - inner: BoxService::new(inner), + inner: BoxService::new(stack.layer(conn)), } } @@ -83,7 +80,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, false).ready_oneshot().await } @@ -93,7 +90,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, true) } @@ -126,3 +123,87 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +struct SendRequest { + inner: hyper::client::conn::http2::SendRequest, +} + +impl From> for SendRequest { + fn from(inner: hyper::client::conn::http2::SendRequest) -> Self { + Self { inner } + } +} + +impl tower::Service> for SendRequest { + type Response = http::Response; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.inner.send_request(req); + + Box::pin(async move { + fut.await + .map_err(Into::into) + .map(|res| res.map(|body| boxed(body))) + }) + } +} + +struct MakeSendRequestService { + connector: C, + executor: super::SharedExec, + settings: Builder, +} + +impl MakeSendRequestService { + fn new(connector: C, executor: SharedExec, settings: Builder) -> Self { + Self { + connector, + executor, + settings, + } + } +} + +impl tower::Service for MakeSendRequestService +where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, +{ + type Response = SendRequest; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.connector.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let fut = self.connector.call(req); + let builder = self.settings.clone(); + let executor = self.executor.clone(); + + Box::pin(async move { + let io = fut.await.map_err(Into::into)?; + let (send_request, conn) = builder.handshake(io).await?; + + Executor::>::execute( + &executor, + Box::pin(async move { + if let Err(e) = conn.await { + tracing::debug!("connection task error: {:?}", e); + } + }) as _, + ); + + Ok(SendRequest::from(send_request)) + }) + } +} diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 12336813a..978441d75 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -3,12 +3,32 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; -#[cfg(feature = "tls")] use std::fmt; use std::task::{Context, Poll}; -use tower::make::MakeConnection; + +use hyper::rt; + +#[cfg(feature = "tls")] +use hyper_util::rt::TokioIo; use tower_service::Service; +/// Wrapper type to indicate that an error occurs during the connection +/// process, so that the appropriate gRPC Status can be inferred. +#[derive(Debug)] +pub(crate) struct ConnectError(pub(crate) crate::Error); + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl std::error::Error for ConnectError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.0.as_ref()) + } +} + pub(crate) struct Connector { inner: C, #[cfg(feature = "tls")] @@ -51,17 +71,19 @@ impl Connector { impl Service for Connector where - C: MakeConnection, - C::Connection: Unpin + Send + 'static, + C: Service, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { type Response = BoxedIo; - type Error = crate::Error; + type Error = ConnectError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into) + self.inner + .poll_ready(cx) + .map_err(|err| ConnectError(From::from(err))) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -73,26 +95,30 @@ where #[cfg(feature = "tls")] let is_https = uri.scheme_str() == Some("https"); - let connect = self.inner.make_connection(uri); + let connect = self.inner.call(uri); Box::pin(async move { - let io = connect.await?; - - #[cfg(feature = "tls")] - { - if let Some(tls) = tls { - if is_https { - let conn = tls.connect(io).await?; - return Ok(BoxedIo::new(conn)); - } else { - return Ok(BoxedIo::new(io)); + async { + let io = connect.await?; + + #[cfg(feature = "tls")] + { + if let Some(tls) = tls { + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) + } else { + Ok(BoxedIo::new(io)) + }; + } else if is_https { + return Err(HttpsUriWithoutTlsSupport(()).into()); } - } else if is_https { - return Err(HttpsUriWithoutTlsSupport(()).into()); } - } - Ok(BoxedIo::new(io)) + Ok::<_, crate::Error>(BoxedIo::new(io)) + } + .await + .map_err(|err| ConnectError(From::from(err))) }) } } diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 2d23ca74c..b9356110e 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -1,6 +1,7 @@ use super::connection::Connection; use crate::transport::Endpoint; +use hyper_util::client::legacy::connect::HttpConnector; use std::{ hash::Hash, pin::Pin, @@ -32,7 +33,7 @@ impl Stream for DynamicServiceStream { Poll::Pending | Poll::Ready(None) => Poll::Pending, Poll::Ready(Some(change)) => match change { Change::Insert(k, endpoint) => { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.set_nodelay(endpoint.tcp_nodelay); http.set_keepalive(endpoint.tcp_keepalive); http.set_connect_timeout(endpoint.connect_timeout); diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/service/executor.rs index de3cfbe6e..7b699c307 100644 --- a/tonic/src/transport/service/executor.rs +++ b/tonic/src/transport/service/executor.rs @@ -36,8 +36,11 @@ impl SharedExec { } } -impl Executor> for SharedExec { - fn execute(&self, fut: BoxFuture<'static, ()>) { - self.inner.execute(fut) +impl Executor for SharedExec +where + F: Future + Send + 'static, +{ + fn execute(&self, fut: F) { + self.inner.execute(Box::pin(fut)) } } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2230b9b2e..cb2296cac 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,5 +1,6 @@ use crate::transport::server::Connected; -use hyper::client::connect::{Connected as HyperConnected, Connection}; +use hyper::rt; +use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -9,11 +10,11 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static + rt::Read + rt::Write + Send + 'static { } -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} +impl Io for T where T: rt::Read + rt::Write + Send + 'static {} pub(crate) struct BoxedIo(Pin>); @@ -40,17 +41,17 @@ impl Connected for BoxedIo { #[derive(Copy, Clone)] pub(crate) struct NoneConnectInfo; -impl AsyncRead for BoxedIo { +impl rt::Read for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } -impl AsyncWrite for BoxedIo { +impl rt::Write for BoxedIo { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 69d850f10..2b2a84070 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -13,6 +13,7 @@ mod user_agent; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; +pub(crate) use self::connector::ConnectError; pub(crate) use self::connector::Connector; pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::executor::SharedExec; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 85636c4d4..c43782ba9 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,9 +1,9 @@ use crate::{ body::{boxed, BoxBody}, server::NamedService, + transport::BoxFuture, }; use http::{Request, Response}; -use hyper::Body; use pin_project::pin_project; use std::{ convert::Infallible, @@ -12,7 +12,6 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tower::ServiceExt; use tower_service::Service; /// A [`Service`] router. @@ -31,7 +30,7 @@ impl RoutesBuilder { /// Add a new service. pub fn add_service(&mut self, svc: S) -> &mut Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -53,7 +52,7 @@ impl Routes { /// Create a new routes with `svc` already added to it. pub fn new(svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -68,7 +67,7 @@ impl Routes { /// Add a new service. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -76,10 +75,10 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let svc = svc.map_response(|res| res.map(axum::body::boxed)); - self.router = self - .router - .route_service(&format!("/{}/*rest", S::NAME), svc); + self.router = self.router.route_service( + &format!("/{}/*rest", S::NAME), + AxumBodyService { service: svc }, + ); self } @@ -103,7 +102,7 @@ async fn unimplemented() -> impl axum::response::IntoResponse { (status, headers) } -impl Service> for Routes { +impl Service> for Routes { type Response = Response; type Error = crate::Error; type Future = RoutesFuture; @@ -113,13 +112,13 @@ impl Service> for Routes { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { RoutesFuture(self.router.call(req)) } } #[pin_project] -pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); +pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); impl fmt::Debug for RoutesFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -137,3 +136,33 @@ impl Future for RoutesFuture { } } } + +#[derive(Clone)] +struct AxumBodyService { + service: S, +} + +impl Service> for AxumBodyService +where + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.service.call(req.map(|body| boxed(body))); + Box::pin(async move { + fut.await + .map(|res| res.map(|body| axum::body::Body::new(body))) + }) + } +} diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 2a6394a4f..2ce9dc5da 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -18,6 +18,7 @@ use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; +use hyper_util::rt::TokioIo; /// h2 alpn in plain format for rustls. const ALPN_H2: &[u8] = b"h2"; @@ -88,7 +89,7 @@ impl TlsConnector { if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { return Err(TlsError::H2NotNegotiated.into()); } - Ok(BoxedIo::new(io)) + Ok(BoxedIo::new(TokioIo::new(io))) } }