From 234c8ccb13e41c94e32d68c13071de1af4be2ea0 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 25 Jul 2022 20:06:37 +0200 Subject: [PATCH] Improve build times by generating less IR (#1192) * example * `MethodRouter::merge` * `set_content_length` and `set_allow_header` * `MethodRouter::on_service_boxed_response_body` * `Router::route` * `MethodRouter::merge` again * `MethodRouter::on_service_boxed_response_body` * `Router::call_route` * `MethodRouter::{layer, route_layer}` * revert example * fix test * move function to method on `AllowHeader` --- axum/src/routing/method_routing.rs | 356 ++++++++++++++--------------- axum/src/routing/mod.rs | 74 +++--- axum/src/routing/route.rs | 23 +- 3 files changed, 232 insertions(+), 221 deletions(-) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index fc86c9ab56..f948fab74a 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -13,10 +13,9 @@ use bytes::BytesMut; use std::{ convert::Infallible, fmt, - marker::PhantomData, task::{Context, Poll}, }; -use tower::{service_fn, util::MapResponseLayer, ServiceBuilder}; +use tower::{service_fn, util::MapResponseLayer}; use tower_layer::Layer; use tower_service::Service; @@ -482,7 +481,6 @@ pub struct MethodRouter { trace: Option>, fallback: Fallback, allow_header: AllowHeader, - _request_body: PhantomData (B, E)>, } #[derive(Clone)] @@ -495,6 +493,22 @@ enum AllowHeader { Bytes(BytesMut), } +impl AllowHeader { + fn merge(self, other: Self) -> Self { + match (self, other) { + (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip, + (AllowHeader::None, AllowHeader::None) => AllowHeader::None, + (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick), + (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick), + (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => { + a.extend_from_slice(b","); + a.extend_from_slice(&b); + AllowHeader::Bytes(a) + } + } + } +} + impl fmt::Debug for MethodRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MethodRouter") @@ -532,7 +546,6 @@ impl MethodRouter { trace: None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), - _request_body: PhantomData, } } } @@ -723,12 +736,11 @@ impl MethodRouter { >>::Response: IntoResponse + 'static, >>::Future: Send + 'static, { - let layer = ServiceBuilder::new() - .layer_fn(Route::new) - .layer(MapResponseLayer::new(IntoResponse::into_response)) - .layer(layer) - .into_inner(); - let layer_fn = |s| layer.layer(s); + let layer_fn = |svc| { + let svc = layer.layer(svc); + let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); + Route::new(svc) + }; MethodRouter { get: self.get.map(layer_fn), @@ -741,128 +753,77 @@ impl MethodRouter { trace: self.trace.map(layer_fn), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, - _request_body: PhantomData, } } #[doc = include_str!("../docs/method_routing/route_layer.md")] - pub fn route_layer(self, layer: L) -> MethodRouter + pub fn route_layer(mut self, layer: L) -> MethodRouter where L: Layer>, L::Service: Service, Error = E> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Future: Send + 'static, { - let layer = ServiceBuilder::new() - .layer_fn(Route::new) - .layer(MapResponseLayer::new(IntoResponse::into_response)) - .layer(layer) - .into_inner(); - let layer_fn = |s| layer.layer(s); + let layer_fn = |svc| { + let svc = layer.layer(svc); + let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); + Route::new(svc) + }; - MethodRouter { - get: self.get.map(layer_fn), - head: self.head.map(layer_fn), - delete: self.delete.map(layer_fn), - options: self.options.map(layer_fn), - patch: self.patch.map(layer_fn), - post: self.post.map(layer_fn), - put: self.put.map(layer_fn), - trace: self.trace.map(layer_fn), - fallback: self.fallback, - allow_header: self.allow_header, - _request_body: PhantomData, - } + self.get = self.get.map(layer_fn); + self.head = self.head.map(layer_fn); + self.delete = self.delete.map(layer_fn); + self.options = self.options.map(layer_fn); + self.patch = self.patch.map(layer_fn); + self.post = self.post.map(layer_fn); + self.put = self.put.map(layer_fn); + self.trace = self.trace.map(layer_fn); + + self } #[doc = include_str!("../docs/method_routing/merge.md")] - pub fn merge(self, other: MethodRouter) -> Self { - macro_rules! merge { - ( $first:ident, $second:ident ) => { - match ($first, $second) { - (Some(_), Some(_)) => panic!(concat!( - "Overlapping method route. Cannot merge two method routes that both define `", - stringify!($first), - "`" - )), - (Some(svc), None) => Some(svc), - (None, Some(svc)) => Some(svc), - (None, None) => None, + pub fn merge(mut self, other: MethodRouter) -> Self { + // written using inner functions to generate less IR + fn merge_inner(name: &str, first: Option, second: Option) -> Option { + match (first, second) { + (Some(_), Some(_)) => panic!( + "Overlapping method route. Cannot merge two method routes that both define `{}`", name + ), + (Some(svc), None) => Some(svc), + (None, Some(svc)) => Some(svc), + (None, None) => None, + } + } + + fn merge_fallback( + fallback: Fallback, + fallback_other: Fallback, + ) -> Fallback { + match (fallback, fallback_other) { + (pick @ Fallback::Default(_), Fallback::Default(_)) => pick, + (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick, + (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick, + (Fallback::Custom(_), Fallback::Custom(_)) => { + panic!("Cannot merge two `MethodRouter`s that both have a fallback") } - }; + } } - let Self { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, - _request_body: _, - } = self; + self.get = merge_inner("get", self.get, other.get); + self.head = merge_inner("head", self.head, other.head); + self.delete = merge_inner("delete", self.delete, other.delete); + self.options = merge_inner("options", self.options, other.options); + self.patch = merge_inner("patch", self.patch, other.patch); + self.post = merge_inner("post", self.post, other.post); + self.put = merge_inner("put", self.put, other.put); + self.trace = merge_inner("trace", self.trace, other.trace); - let Self { - get: get_other, - head: head_other, - delete: delete_other, - options: options_other, - patch: patch_other, - post: post_other, - put: put_other, - trace: trace_other, - fallback: fallback_other, - allow_header: allow_header_other, - _request_body: _, - } = other; - - let get = merge!(get, get_other); - let head = merge!(head, head_other); - let delete = merge!(delete, delete_other); - let options = merge!(options, options_other); - let patch = merge!(patch, patch_other); - let post = merge!(post, post_other); - let put = merge!(put, put_other); - let trace = merge!(trace, trace_other); - - let fallback = match (fallback, fallback_other) { - (pick @ Fallback::Default(_), Fallback::Default(_)) => pick, - (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick, - (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick, - (Fallback::Custom(_), Fallback::Custom(_)) => { - panic!("Cannot merge two `MethodRouter`s that both have a fallback") - } - }; + self.fallback = merge_fallback(self.fallback, other.fallback); - let allow_header = match (allow_header, allow_header_other) { - (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip, - (AllowHeader::None, AllowHeader::None) => AllowHeader::None, - (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick), - (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick), - (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => { - a.extend_from_slice(b","); - a.extend_from_slice(&b); - AllowHeader::Bytes(a) - } - }; + self.allow_header = self.allow_header.merge(other.allow_header); - Self { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, - _request_body: PhantomData, - } + self } /// Apply a [`HandleErrorLayer`]. @@ -882,81 +843,118 @@ impl MethodRouter { self.layer(HandleErrorLayer::new(f)) } - fn on_service_boxed_response_body(self, filter: MethodFilter, svc: S) -> Self + fn on_service_boxed_response_body(mut self, filter: MethodFilter, svc: S) -> Self where S: Service, Error = E> + Clone + Send + 'static, S::Response: IntoResponse + 'static, S::Future: Send + 'static, { - macro_rules! set_service { - ( - $filter:ident, - $svc:ident, - $allow_header:ident, - [ - $( - ($out:ident, $variant:ident, [$($method:literal),+]) - ),+ - $(,)? - ] - ) => { - $( - if $filter.contains(MethodFilter::$variant) { - if $out.is_some() { - panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", stringify!($variant)) - } - $out = $svc.clone(); - $( - append_allow_header(&mut $allow_header, $method); - )+ - } - )+ + // written using an inner function to generate less IR + fn set_service( + method_name: &str, + out: &mut Option, + svc: &T, + svc_filter: MethodFilter, + filter: MethodFilter, + allow_header: &mut AllowHeader, + methods: &[&'static str], + ) where + T: Clone, + { + if svc_filter.contains(filter) { + if out.is_some() { + panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", method_name) + } + *out = Some(svc.clone()); + for method in methods { + append_allow_header(allow_header, method); + } } } - // written with a pattern match like this to ensure we update all fields - let Self { - mut get, - mut head, - mut delete, - mut options, - mut patch, - mut post, - mut put, - mut trace, - fallback, - mut allow_header, - _request_body: _, - } = self; - let svc = Some(Route::new(svc)); - set_service!( + let svc = Route::new(svc); + + set_service( + "GET", + &mut self.get, + &svc, filter, - svc, - allow_header, - [ - (get, GET, ["GET", "HEAD"]), - (head, HEAD, ["HEAD"]), - (delete, DELETE, ["DELETE"]), - (options, OPTIONS, ["OPTIONS"]), - (patch, PATCH, ["PATCH"]), - (post, POST, ["POST"]), - (put, PUT, ["PUT"]), - (trace, TRACE, ["TRACE"]), - ] + MethodFilter::GET, + &mut self.allow_header, + &["GET", "HEAD"], ); - Self { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, - _request_body: PhantomData, - } + + set_service( + "HEAD", + &mut self.head, + &svc, + filter, + MethodFilter::HEAD, + &mut self.allow_header, + &["HEAD"], + ); + + set_service( + "TRACE", + &mut self.trace, + &svc, + filter, + MethodFilter::TRACE, + &mut self.allow_header, + &["TRACE"], + ); + + set_service( + "PUT", + &mut self.put, + &svc, + filter, + MethodFilter::PUT, + &mut self.allow_header, + &["PUT"], + ); + + set_service( + "POST", + &mut self.post, + &svc, + filter, + MethodFilter::POST, + &mut self.allow_header, + &["POST"], + ); + + set_service( + "PATCH", + &mut self.patch, + &svc, + filter, + MethodFilter::PATCH, + &mut self.allow_header, + &["PATCH"], + ); + + set_service( + "OPTIONS", + &mut self.options, + &svc, + filter, + MethodFilter::OPTIONS, + &mut self.allow_header, + &["OPTIONS"], + ); + + set_service( + "DELETE", + &mut self.delete, + &svc, + filter, + MethodFilter::DELETE, + &mut self.allow_header, + &["DELETE"], + ); + + self } fn skip_allow_header(mut self) -> Self { @@ -998,7 +996,6 @@ impl Clone for MethodRouter { trace: self.trace.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), - _request_body: PhantomData, } } } @@ -1056,7 +1053,6 @@ where trace, fallback, allow_header, - _request_body: _, } = self; call!(req, method, HEAD, head); @@ -1091,7 +1087,7 @@ mod tests { use axum_core::response::IntoResponse; use http::{header::ALLOW, HeaderMap}; use std::time::Duration; - use tower::{timeout::TimeoutLayer, Service, ServiceExt}; + use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt}; use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir}; #[tokio::test] diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 135363971b..eda58c732c 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -126,12 +126,16 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - if path.is_empty() { - panic!("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - panic!("Paths must start with a `/`"); + fn validate_path(path: &str) { + if path.is_empty() { + panic!("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + panic!("Paths must start with a `/`"); + } } + validate_path(path); + let service = match try_downcast::, _>(service) { Ok(_) => { panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") @@ -162,16 +166,20 @@ where Err(service) => Endpoint::Route(Route::new(service)), }; + self.set_node(path, id); + + self.routes.insert(id, service); + + self + } + + fn set_node(&mut self, path: &str, id: RouteId) { let mut node = Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); if let Err(err) = node.insert(path, id) { self.panic_on_matchit_error(err); } self.node = Arc::new(node); - - self.routes.insert(id, service); - - self } #[doc = include_str!("../docs/routing/nest.md")] @@ -419,28 +427,38 @@ where let id = *match_.value; #[cfg(feature = "matched-path")] - if let Some(matched_path) = self.node.route_id_to_path.get(&id) { - use crate::extract::MatchedPath; - - let matched_path = if let Some(previous) = req.extensions_mut().get::() { - // a previous `MatchedPath` might exist if we're inside a nested Router - let previous = if let Some(previous) = - previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) - { - previous + { + fn set_matched_path( + id: RouteId, + route_id_to_path: &HashMap>, + extensions: &mut http::Extensions, + ) { + if let Some(matched_path) = route_id_to_path.get(&id) { + use crate::extract::MatchedPath; + + let matched_path = if let Some(previous) = extensions.get::() { + // a previous `MatchedPath` might exist if we're inside a nested Router + let previous = if let Some(previous) = + previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) + { + previous + } else { + previous.as_str() + }; + + let matched_path = format!("{}{}", previous, matched_path); + matched_path.into() + } else { + Arc::clone(matched_path) + }; + extensions.insert(MatchedPath(matched_path)); } else { - previous.as_str() - }; + #[cfg(debug_assertions)] + panic!("should always have a matched path for a route id"); + } + } - let matched_path = format!("{}{}", previous, matched_path); - matched_path.into() - } else { - Arc::clone(matched_path) - }; - req.extensions_mut().insert(MatchedPath(matched_path)); - } else { - #[cfg(debug_assertions)] - panic!("should always have a matched path for a route id"); + set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut()); } url_params::insert_url_params(req.extensions_mut(), match_.params); diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 71ff3e0529..e116c57aa7 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -6,7 +6,7 @@ use axum_core::response::IntoResponse; use bytes::Bytes; use http::{ header::{self, CONTENT_LENGTH}, - HeaderValue, Request, + HeaderMap, HeaderValue, Request, }; use pin_project_lite::pin_project; use std::{ @@ -161,10 +161,10 @@ where res.extensions_mut().insert(AlreadyPassedThroughRouteFuture); } - set_allow_header(&mut res, this.allow_header); + set_allow_header(res.headers_mut(), this.allow_header); // make sure to set content-length before removing the body - set_content_length(&mut res); + set_content_length(res.size_hint(), res.headers_mut()); let res = if *this.strip_body { res.map(|_| boxed(Empty::new())) @@ -176,10 +176,10 @@ where } } -fn set_allow_header(res: &mut Response, allow_header: &mut Option) { +fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option) { match allow_header.take() { - Some(allow_header) if !res.headers().contains_key(header::ALLOW) => { - res.headers_mut().insert( + Some(allow_header) if !headers.contains_key(header::ALLOW) => { + headers.insert( header::ALLOW, HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"), ); @@ -188,15 +188,12 @@ fn set_allow_header(res: &mut Response, allow_header: &mut Option) } } -fn set_content_length(res: &mut Response) -where - B: HttpBody, -{ - if res.headers().contains_key(CONTENT_LENGTH) { +fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) { + if headers.contains_key(CONTENT_LENGTH) { return; } - if let Some(size) = res.size_hint().exact() { + if let Some(size) = size_hint.exact() { let header_value = if size == 0 { #[allow(clippy::declare_interior_mutable_const)] const ZERO: HeaderValue = HeaderValue::from_static("0"); @@ -207,7 +204,7 @@ where HeaderValue::from_str(buffer.format(size)).unwrap() }; - res.headers_mut().insert(CONTENT_LENGTH, header_value); + headers.insert(CONTENT_LENGTH, header_value); } }