Skip to content

Commit

Permalink
Add RouterService::{layer, route_layer} (#1550)
Browse files Browse the repository at this point in the history
* Add `RouterService::{layer, route_layer}`

Figure we might as well have these.

* changelog
  • Loading branch information
davidpdrsn committed Nov 19, 2022
1 parent ce8ea56 commit b816ac7
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 2 deletions.
4 changes: 3 additions & 1 deletion axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **added:** Add `RouterService::{layer, route_layer}` ([#1550])

[#1550]: https://github.com/tokio-rs/axum/pull/1550

# 0.6.0-rc.5 (18. November, 2022)

Expand Down
21 changes: 20 additions & 1 deletion axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,14 @@ where
}

#[doc = include_str!("../docs/routing/layer.md")]
pub fn layer<L, NewReqBody: 'static>(self, layer: L) -> Router<S, NewReqBody>
pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
{
let routes = self
.routes
Expand Down Expand Up @@ -566,6 +567,24 @@ pub(crate) enum FallbackRoute<B, E = Infallible> {
Service(Route<B, E>),
}

impl<B, E> FallbackRoute<B, E> {
fn layer<L, NewReqBody, NewError>(self, layer: L) -> FallbackRoute<NewReqBody, NewError>
where
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
NewError: 'static,
{
match self {
FallbackRoute::Default(route) => FallbackRoute::Default(route.layer(layer)),
FallbackRoute::Service(route) => FallbackRoute::Service(route.layer(layer)),
}
}
}

impl<B, E> fmt::Debug for FallbackRoute<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down
54 changes: 54 additions & 0 deletions axum/src/routing/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
body::{Body, HttpBody},
response::Response,
};
use axum_core::response::IntoResponse;
use http::Request;
use matchit::MatchError;
use std::{
Expand All @@ -15,6 +16,7 @@ use std::{
};
use sync_wrapper::SyncWrapper;
use tower::Service;
use tower_layer::Layer;

/// A [`Router`] converted into a [`Service`].
#[derive(Debug)]
Expand Down Expand Up @@ -76,6 +78,58 @@ where
route.call(req)
}

/// Apply a [`tower::Layer`] to all routes in the router.
///
/// See [`Router::layer`] for more details.
pub fn layer<L, NewReqBody>(self, layer: L) -> RouterService<NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
{
let routes = self
.routes
.into_iter()
.map(|(id, route)| (id, route.layer(layer.clone())))
.collect();

let fallback = self.fallback.layer(layer);

RouterService {
routes,
node: self.node,
fallback,
}
}

/// Apply a [`tower::Layer`] to the router that will only run if the request matches
/// a route.
///
/// See [`Router::route_layer`] for more details.
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<B>> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static,
{
let routes = self
.routes
.into_iter()
.map(|(id, route)| (id, route.layer(layer.clone())))
.collect();

Self {
routes,
node: self.node,
fallback: self.fallback,
}
}

/// Convert the `RouterService` into a [`MakeService`].
///
/// See [`Router::into_make_service`] for more details.
Expand Down

0 comments on commit b816ac7

Please sign in to comment.