diff --git a/tower-http/src/timeout.rs b/tower-http/src/timeout.rs deleted file mode 100644 index 6c669e9b..00000000 --- a/tower-http/src/timeout.rs +++ /dev/null @@ -1,156 +0,0 @@ -//! Middleware that applies a timeout to requests. -//! -//! If the request does not complete within the specified timeout it will be aborted and a `408 -//! Request Timeout` response will be sent. -//! -//! # Differences from `tower::timeout` -//! -//! tower's [`Timeout`](tower::timeout::Timeout) middleware uses an error to signal timeout, i.e. -//! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely -//! what you want as returning errors will terminate the connection without sending a response. -//! -//! This middleware won't change the error type and instead return a `408 Request Timeout` -//! response. That means if your service's error type is [`Infallible`] it will still be -//! [`Infallible`] after applying this middleware. -//! -//! # Example -//! -//! ``` -//! use http::{Request, Response}; -//! use hyper::Body; -//! use std::{convert::Infallible, time::Duration}; -//! use tower::ServiceBuilder; -//! use tower_http::timeout::TimeoutLayer; -//! -//! async fn handle(_: Request) -> Result, Infallible> { -//! // ... -//! # Ok(Response::new(Body::empty())) -//! } -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { -//! let svc = ServiceBuilder::new() -//! // Timeout requests after 30 seconds -//! .layer(TimeoutLayer::new(Duration::from_secs(30))) -//! .service_fn(handle); -//! # Ok(()) -//! # } -//! ``` -//! -//! [`Infallible`]: std::convert::Infallible - -use http::{Request, Response, StatusCode}; -use pin_project_lite::pin_project; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; -use tokio::time::Sleep; -use tower_layer::Layer; -use tower_service::Service; - -/// Layer that applies the [`Timeout`] middleware which apply a timeout to requests. -/// -/// See the [module docs](self) for an example. -#[derive(Debug, Clone, Copy)] -pub struct TimeoutLayer { - timeout: Duration, -} - -impl TimeoutLayer { - /// Create a new [`TimeoutLayer`]. - pub fn new(timeout: Duration) -> Self { - TimeoutLayer { timeout } - } -} - -impl Layer for TimeoutLayer { - type Service = Timeout; - - fn layer(&self, inner: S) -> Self::Service { - Timeout::new(inner, self.timeout) - } -} - -/// Middleware which apply a timeout to requests. -/// -/// If the request does not complete within the specified timeout it will be aborted and a `408 -/// Request Timeout` response will be sent. -/// -/// See the [module docs](self) for an example. -#[derive(Debug, Clone, Copy)] -pub struct Timeout { - inner: S, - timeout: Duration, -} - -impl Timeout { - /// Create a new [`Timeout`]. - pub fn new(inner: S, timeout: Duration) -> Self { - Self { inner, timeout } - } - - define_inner_service_accessors!(); - - /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. - /// - /// [`Layer`]: tower_layer::Layer - pub fn layer(timeout: Duration) -> TimeoutLayer { - TimeoutLayer::new(timeout) - } -} - -impl Service> for Timeout -where - S: Service, Response = Response>, - ResBody: Default, -{ - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture; - - #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Request) -> Self::Future { - let sleep = tokio::time::sleep(self.timeout); - ResponseFuture { - inner: self.inner.call(req), - sleep, - } - } -} - -pin_project! { - /// Response future for [`Timeout`]. - pub struct ResponseFuture { - #[pin] - inner: F, - #[pin] - sleep: Sleep, - } -} - -impl Future for ResponseFuture -where - F: Future, E>>, - B: Default, -{ - type Output = Result, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if this.sleep.poll(cx).is_ready() { - let mut res = Response::new(B::default()); - *res.status_mut() = StatusCode::REQUEST_TIMEOUT; - return Poll::Ready(Ok(res)); - } - - this.inner.poll(cx) - } -} diff --git a/tower-http/src/timeout/body.rs b/tower-http/src/timeout/body.rs new file mode 100644 index 00000000..79712efd --- /dev/null +++ b/tower-http/src/timeout/body.rs @@ -0,0 +1,219 @@ +//! Middleware that applies a timeout to request and response bodies. +//! +//! Bodies must produce data at most within the specified timeout. +//! If the body does not produce a requested data frame within the timeout period, it will return an error. +//! +//! # Differences from [`crate::timeout::Timeout`] +//! +//! [`crate::timeout::Timeout`] applies a timeout to the request future, not body. +//! That timeout is not reset when bytes are handled, whether the request is active or not. +//! Bodies are handled asynchronously outside of the tower stack's future and thus needs an additional timeout. +//! +//! This middleware will return a [`TimeoutError`]. +//! +//! # Example +//! +//! ``` +//! use http::{Request, Response}; +//! use hyper::Body; +//! use std::time::Duration; +//! use tower::ServiceBuilder; +//! use tower_http::timeout::RequestBodyTimeoutLayer; +//! +//! async fn handle(_: Request) -> Result, std::convert::Infallible> { +//! // ... +//! # todo!() +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let svc = ServiceBuilder::new() +//! // Timeout bodies after 30 seconds of inactivity +//! .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30))) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use crate::BoxError; +use futures_core::{ready, Future}; +use http_body::Body; +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::{sleep, Sleep}; + +pin_project! { + /// Wrapper around a [`http_body::Body`] to time out if data is not ready within the specified duration. + pub struct TimeoutBody { + timeout: Duration, + // In http-body 1.0, `poll_*` will be merged into `poll_frame`. + // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. + // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 + #[pin] + sleep_data: Option, + #[pin] + sleep_trailers: Option, + #[pin] + body: B, + } +} + +impl TimeoutBody { + /// Creates a new [`TimeoutBody`]. + pub fn new(timeout: Duration, body: B) -> Self { + TimeoutBody { + timeout, + sleep_data: None, + sleep_trailers: None, + body, + } + } +} + +impl Body for TimeoutBody +where + B: Body, + B::Error: Into, +{ + type Data = B::Data; + type Error = Box; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut this = self.project(); + + // Start the `Sleep` if not active. + let sleep_pinned = if let Some(some) = this.sleep_data.as_mut().as_pin_mut() { + some + } else { + this.sleep_data.set(Some(sleep(*this.timeout))); + this.sleep_data.as_mut().as_pin_mut().unwrap() + }; + + // Error if the timeout has expired. + if let Poll::Ready(()) = sleep_pinned.poll(cx) { + return Poll::Ready(Some(Err(Box::new(TimeoutError(()))))); + } + + // Check for body data. + let data = ready!(this.body.poll_data(cx)); + // Some data is ready. Reset the `Sleep`... + this.sleep_data.set(None); + + Poll::Ready(data.transpose().map_err(Into::into).transpose()) + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let mut this = self.project(); + + // In http-body 1.0, `poll_*` will be merged into `poll_frame`. + // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. + // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 + + let sleep_pinned = if let Some(some) = this.sleep_trailers.as_mut().as_pin_mut() { + some + } else { + this.sleep_trailers.set(Some(sleep(*this.timeout))); + this.sleep_trailers.as_mut().as_pin_mut().unwrap() + }; + + // Error if the timeout has expired. + if let Poll::Ready(()) = sleep_pinned.poll(cx) { + return Poll::Ready(Err(Box::new(TimeoutError(())))); + } + + this.body.poll_trailers(cx).map_err(Into::into) + } +} + +/// Error for [`TimeoutBody`]. +#[derive(Debug)] +pub struct TimeoutError(()); + +impl std::error::Error for TimeoutError {} + +impl std::fmt::Display for TimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "data was not received within the designated timeout") + } +} +#[cfg(test)] +mod tests { + use super::*; + + use bytes::Bytes; + use pin_project_lite::pin_project; + use std::{error::Error, fmt::Display}; + + #[derive(Debug)] + struct MockError; + + impl Error for MockError {} + impl Display for MockError { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + todo!() + } + } + + pin_project! { + struct MockBody { + #[pin] + sleep: Sleep + } + } + + impl Body for MockBody { + type Data = Bytes; + type Error = MockError; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + this.sleep.poll(cx).map(|_| Some(Ok(vec![].into()))) + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + todo!() + } + } + + #[tokio::test] + async fn test_body_available_within_timeout() { + let mock_sleep = Duration::from_secs(1); + let timeout_sleep = Duration::from_secs(2); + + let mock_body = MockBody { + sleep: sleep(mock_sleep), + }; + let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); + + assert!(timeout_body.boxed().data().await.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_body_unavailable_within_timeout_error() { + let mock_sleep = Duration::from_secs(2); + let timeout_sleep = Duration::from_secs(1); + + let mock_body = MockBody { + sleep: sleep(mock_sleep), + }; + let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); + + assert!(timeout_body.boxed().data().await.unwrap().is_err()); + } +} diff --git a/tower-http/src/timeout/mod.rs b/tower-http/src/timeout/mod.rs new file mode 100644 index 00000000..4cbe476f --- /dev/null +++ b/tower-http/src/timeout/mod.rs @@ -0,0 +1,10 @@ +//! Middleware for setting timeouts on requests and responses. + +mod body; +mod service; + +pub use body::{TimeoutBody, TimeoutError}; +pub use service::{ + RequestBodyTimeout, RequestBodyTimeoutLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer, + Timeout, TimeoutLayer, +}; diff --git a/tower-http/src/timeout/service.rs b/tower-http/src/timeout/service.rs new file mode 100644 index 00000000..13347a34 --- /dev/null +++ b/tower-http/src/timeout/service.rs @@ -0,0 +1,315 @@ +//! Middleware that applies a timeout to requests. +//! +//! If the request does not complete within the specified timeout it will be aborted and a `408 +//! Request Timeout` response will be sent. +//! +//! # Differences from `tower::timeout` +//! +//! tower's [`Timeout`](tower::timeout::Timeout) middleware uses an error to signal timeout, i.e. +//! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely +//! what you want as returning errors will terminate the connection without sending a response. +//! +//! This middleware won't change the error type and instead return a `408 Request Timeout` +//! response. That means if your service's error type is [`Infallible`] it will still be +//! [`Infallible`] after applying this middleware. +//! +//! # Example +//! +//! ``` +//! use http::{Request, Response}; +//! use hyper::Body; +//! use std::{convert::Infallible, time::Duration}; +//! use tower::ServiceBuilder; +//! use tower_http::timeout::TimeoutLayer; +//! +//! async fn handle(_: Request) -> Result, Infallible> { +//! // ... +//! # Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let svc = ServiceBuilder::new() +//! // Timeout requests after 30 seconds +//! .layer(TimeoutLayer::new(Duration::from_secs(30))) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! [`Infallible`]: std::convert::Infallible + +use crate::timeout::body::TimeoutBody; +use futures_core::ready; +use http::{Request, Response, StatusCode}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies the [`Timeout`] middleware which apply a timeout to requests. +/// +/// See the [module docs](self) for an example. +#[derive(Debug, Clone, Copy)] +pub struct TimeoutLayer { + timeout: Duration, +} + +impl TimeoutLayer { + /// Creates a new [`TimeoutLayer`]. + pub fn new(timeout: Duration) -> Self { + TimeoutLayer { timeout } + } +} + +impl Layer for TimeoutLayer { + type Service = Timeout; + + fn layer(&self, inner: S) -> Self::Service { + Timeout::new(inner, self.timeout) + } +} + +/// Middleware which apply a timeout to requests. +/// +/// If the request does not complete within the specified timeout it will be aborted and a `408 +/// Request Timeout` response will be sent. +/// +/// See the [module docs](self) for an example. +#[derive(Debug, Clone, Copy)] +pub struct Timeout { + inner: S, + timeout: Duration, +} + +impl Timeout { + /// Creates a new [`Timeout`]. + pub fn new(inner: S, timeout: Duration) -> Self { + Self { inner, timeout } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(timeout: Duration) -> TimeoutLayer { + TimeoutLayer::new(timeout) + } +} + +impl Service> for Timeout +where + S: Service, Response = Response>, + ResBody: Default, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let sleep = tokio::time::sleep(self.timeout); + ResponseFuture { + inner: self.inner.call(req), + sleep, + } + } +} + +pin_project! { + /// Response future for [`Timeout`]. + pub struct ResponseFuture { + #[pin] + inner: F, + #[pin] + sleep: Sleep, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, + B: Default, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if this.sleep.poll(cx).is_ready() { + let mut res = Response::new(B::default()); + *res.status_mut() = StatusCode::REQUEST_TIMEOUT; + return Poll::Ready(Ok(res)); + } + + this.inner.poll(cx) + } +} + +/// Applies a [`TimeoutBody`] to the request body. +#[derive(Clone, Debug)] +pub struct RequestBodyTimeoutLayer { + timeout: Duration, +} + +impl RequestBodyTimeoutLayer { + /// Creates a new [`RequestBodyTimeoutLayer`]. + pub fn new(timeout: Duration) -> Self { + Self { timeout } + } +} + +impl Layer for RequestBodyTimeoutLayer { + type Service = RequestBodyTimeout; + + fn layer(&self, inner: S) -> Self::Service { + RequestBodyTimeout::new(inner, self.timeout) + } +} + +/// Applies a [`TimeoutBody`] to the request body. +#[derive(Clone, Debug)] +pub struct RequestBodyTimeout { + inner: S, + timeout: Duration, +} + +impl RequestBodyTimeout { + /// Creates a new [`RequestBodyTimeout`]. + pub fn new(service: S, timeout: Duration) -> Self { + Self { + inner: service, + timeout, + } + } + + /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer { + RequestBodyTimeoutLayer::new(timeout) + } + + define_inner_service_accessors!(); +} + +impl Service> for RequestBodyTimeout +where + S: Service>>, + S::Error: Into>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let req = req.map(|body| TimeoutBody::new(self.timeout, body)); + self.inner.call(req) + } +} + +/// Applies a [`TimeoutBody`] to the response body. +#[derive(Clone)] +pub struct ResponseBodyTimeoutLayer { + timeout: Duration, +} + +impl ResponseBodyTimeoutLayer { + /// Creates a new [`ResponseBodyTimeoutLayer`]. + pub fn new(timeout: Duration) -> Self { + Self { timeout } + } +} + +impl Layer for ResponseBodyTimeoutLayer { + type Service = ResponseBodyTimeout; + + fn layer(&self, inner: S) -> Self::Service { + ResponseBodyTimeout::new(inner, self.timeout) + } +} + +/// Applies a [`TimeoutBody`] to the response body. +#[derive(Clone)] +pub struct ResponseBodyTimeout { + inner: S, + timeout: Duration, +} + +impl Service> for ResponseBodyTimeout +where + S: Service, Response = Response>, + S::Error: Into>, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseBodyTimeoutFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + ResponseBodyTimeoutFuture { + inner: self.inner.call(req), + timeout: self.timeout, + } + } +} + +impl ResponseBodyTimeout { + /// Creates a new [`ResponseBodyTimeout`]. + pub fn new(service: S, timeout: Duration) -> Self { + Self { + inner: service, + timeout, + } + } + + /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer { + ResponseBodyTimeoutLayer::new(timeout) + } + + define_inner_service_accessors!(); +} + +pin_project! { + /// Response future for [`ResponseBodyTimeout`]. + pub struct ResponseBodyTimeoutFuture { + #[pin] + inner: Fut, + timeout: Duration, + } +} + +impl Future for ResponseBodyTimeoutFuture +where + Fut: Future, E>>, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let timeout = self.timeout; + let this = self.project(); + let res = ready!(this.inner.poll(cx))?; + Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body)))) + } +}