From 5f8a06abb567184a042a19092c0ce50ef179df55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ale=C5=A1=20Bizjak?=
Date: Tue, 30 Jan 2024 22:17:20 +0100
Subject: [PATCH] Add support for load-shedding.

When the per-connection service is not ready this gives an option of
immediately rejecting requests.
---
 tonic/Cargo.toml                  |  2 +-
 tonic/src/status.rs               | 11 +++++++++++
 tonic/src/transport/server/mod.rs | 33 +++++++++++++++++++++++++++++++
 3 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml
index 013cc6e72..9e4c73b6e 100644
--- a/tonic/Cargo.toml
+++ b/tonic/Cargo.toml
@@ -73,7 +73,7 @@ h2 = {version = "0.3.17", optional = true}
 hyper = {version = "0.14.26", features = ["full"], optional = true}
 hyper-timeout = {version = "0.4", optional = true}
 tokio-stream = "0.1"
-tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
+tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util", "load-shed"], optional = true}
 axum = {version = "0.6.9", default_features = false, optional = true}
 
 # rustls
diff --git a/tonic/src/status.rs b/tonic/src/status.rs
index da8b792e5..cb1de459f 100644
--- a/tonic/src/status.rs
+++ b/tonic/src/status.rs
@@ -350,6 +350,17 @@ impl Status {
             Err(err) => err,
         };
 
+        // If the load shed middleware is enabled, respond to
+        // service overloaded with an appropriate grpc status.
+        let err = match err.downcast::<tower::load_shed::error::Overloaded>() {
+            Ok(_) => {
+                return Ok(Status::resource_exhausted(
+                    "Too many active requests for the connection",
+                ));
+            }
+            Err(err) => err,
+        };
+
         if let Some(mut status) = find_status_in_source_chain(&*err) {
             status.source = Some(err.into());
             return Ok(status);
diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs
index 7f2ffde2b..54aa18c60 100644
--- a/tonic/src/transport/server/mod.rs
+++ b/tonic/src/transport/server/mod.rs
@@ -59,6 +59,7 @@ use tower::{
     layer::util::{Identity, Stack},
     layer::Layer,
     limit::concurrency::ConcurrencyLimitLayer,
+    load_shed::LoadShedLayer,
     util::Either,
     Service, ServiceBuilder,
 };
@@ -81,6 +82,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20;
 pub struct Server<L = Identity> {
     trace_interceptor: Option<TraceInterceptor>,
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     #[cfg(feature = "tls")]
     tls: Option<TlsAcceptor>,
@@ -103,6 +105,7 @@ impl Default for Server<Identity> {
         Self {
             trace_interceptor: None,
             concurrency_limit: None,
+            load_shed: false,
             timeout: None,
             #[cfg(feature = "tls")]
             tls: None,
@@ -173,6 +176,27 @@ impl<L> Server<L> {
         }
     }
 
+    /// Enable or disable load shedding. The default is disabled.
+    ///
+    /// When load shedding is enabled, if the service responds with not ready
+    /// the request will immediately be rejected with a
+    /// [`resource_exhausted`](https://docs.rs/tonic/latest/tonic/struct.Status.html#method.resource_exhausted) error.
+    /// The default is to buffer requests. This is especially useful in combination with
+    /// setting a concurrency limit per connection.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use tonic::transport::Server;
+    /// # use tower_service::Service;
+    /// # let builder = Server::builder();
+    /// builder.load_shed(true);
+    /// ```
+    #[must_use]
+    pub fn load_shed(self, load_shed: bool) -> Self {
+        Server { load_shed, ..self }
+    }
+
     /// Set a timeout on for all request handlers.
     ///
     /// # Example
@@ -469,6 +493,7 @@ impl<L> Server<L> {
             service_builder: self.service_builder.layer(new_layer),
             trace_interceptor: self.trace_interceptor,
             concurrency_limit: self.concurrency_limit,
+            load_shed: self.load_shed,
             timeout: self.timeout,
             #[cfg(feature = "tls")]
             tls: self.tls,
@@ -507,6 +532,7 @@ impl<L> Server<L> {
     {
         let trace_interceptor = self.trace_interceptor.clone();
         let concurrency_limit = self.concurrency_limit;
+        let load_shed = self.load_shed;
         let init_connection_window_size = self.init_connection_window_size;
         let init_stream_window_size = self.init_stream_window_size;
         let max_concurrent_streams = self.max_concurrent_streams;
@@ -529,6 +555,7 @@ impl<L> Server<L> {
         let svc = MakeSvc {
             inner: svc,
             concurrency_limit,
+            load_shed,
             timeout,
             trace_interceptor,
             _io: PhantomData,
@@ -815,6 +842,7 @@ impl<S> fmt::Debug for Svc<S> {
 
 struct MakeSvc<S, IO> {
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     inner: S,
     trace_interceptor: Option<TraceInterceptor>,
@@ -848,6 +876,11 @@ where
 
         let svc = ServiceBuilder::new()
             .layer_fn(RecoverError::new)
+            .option_layer(if self.load_shed {
+                Some(LoadShedLayer::new())
+            } else {
+                None
+            })
             .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
             .layer_fn(|s| GrpcTimeout::new(s, timeout))
             .service(svc);