Skip to content

Commit

Permalink
Type safe state inheritance (#1532)
Browse files Browse the repository at this point in the history
* Make state type safe

* fix examples

* remove unnecessary `#[track_caller]`s

* Router::into_service -> Router::with_state

* fixup docs

* macro docs

* add missing docs

* fix examples

* format

* changelog

* Update trybuild tests

* Make sure fallbacks are still inherited for opaque services (#1540)

* Document nesting routers with different state

* fix leftover conflicts
  • Loading branch information
davidpdrsn committed Nov 18, 2022
1 parent ba8e9c1 commit 64960bb
Show file tree
Hide file tree
Showing 62 changed files with 674 additions and 735 deletions.
10 changes: 6 additions & 4 deletions axum-core/src/extract/default_body_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,30 @@ use tower_layer::Layer;
/// http::Request,
/// };
///
/// let router = Router::new()
/// let app = Router::new()
/// .route(
/// "/",
/// // even with `DefaultBodyLimit` the request body is still just `Body`
/// post(|request: Request<Body>| async {}),
/// )
/// .layer(DefaultBodyLimit::max(1024));
/// # let _: Router<(), _> = app;
/// ```
///
/// ```
/// use axum::{Router, routing::post, body::Body, extract::RawBody, http::Request};
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let router = Router::new()
/// let app = Router::new()
/// .route(
/// "/",
/// // `RequestBodyLimitLayer` changes the request body type to `Limited<Body>`
/// // extracting a different body type wont work
/// post(|request: Request<Limited<Body>>| async {}),
/// )
/// .layer(RequestBodyLimitLayer::new(1024));
/// # let _: Router<(), _> = app;
/// ```
///
/// In general using `DefaultBodyLimit` is recommended but if you need to use third party
Expand Down Expand Up @@ -102,7 +104,7 @@ impl DefaultBodyLimit {
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// let app: Router<(), Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Disable the default limit
/// .layer(DefaultBodyLimit::disable())
Expand Down Expand Up @@ -138,7 +140,7 @@ impl DefaultBodyLimit {
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// let app: Router<(), Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Replace the default of 2MB with 1024 bytes.
/// .layer(DefaultBodyLimit::max(1024));
Expand Down
8 changes: 4 additions & 4 deletions axum-extra/src/extract/cookie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ mod tests {
custom_key: CustomKey(Key::generate()),
};

let app = Router::<_, Body>::with_state(state)
let app = Router::<_, Body>::new()
.route("/set", get(set_cookie))
.route("/get", get(get_cookie))
.route("/remove", get(remove_cookie))
.into_service();
.with_state(state);

let res = app
.clone()
Expand Down Expand Up @@ -352,9 +352,9 @@ mod tests {
custom_key: CustomKey(Key::generate()),
};

let app = Router::<_, Body>::with_state(state)
let app = Router::<_, Body>::new()
.route("/get", get(get_cookie))
.into_service();
.with_state(state);

let res = app
.clone()
Expand Down
7 changes: 4 additions & 3 deletions axum-extra/src/extract/cookie/private.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/set", post(set_secret))
/// .route("/get", get(get_secret));
/// # let app: Router<_> = app;
/// .route("/get", get(get_secret))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub struct PrivateCookieJar<K = Key> {
jar: cookie::CookieJar,
Expand Down
7 changes: 4 additions & 3 deletions axum-extra/src/extract/cookie/signed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/sessions", post(create_session))
/// .route("/me", get(me));
/// # let app: Router<_> = app;
/// .route("/me", get(me))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub struct SignedCookieJar<K = Key> {
jar: cookie::CookieJar,
Expand Down
18 changes: 3 additions & 15 deletions axum-extra/src/routing/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,18 @@ pub struct Resource<S = (), B = Body> {
pub(crate) router: Router<S, B>,
}

impl<B> Resource<(), B>
where
B: axum::body::HttpBody + Send + 'static,
{
/// Create a `Resource` with the given name.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named(resource_name: &str) -> Self {
Self::named_with((), resource_name)
}
}

impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
/// Create a `Resource` with the given name.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named_with(state: S, resource_name: &str) -> Self {
pub fn named(resource_name: &str) -> Self {
Self {
name: resource_name.to_owned(),
router: Router::with_state(state),
router: Router::new(),
}
}

Expand Down
24 changes: 13 additions & 11 deletions axum-extra/src/routing/spa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ use tower_service::Service;
/// - `GET /some/other/path` will serve `index.html` since there isn't another
/// route for it
/// - `GET /api/foo` will serve the `api_foo` handler function
pub struct SpaRouter<B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
pub struct SpaRouter<S = (), B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
paths: Arc<Paths>,
handle_error: F,
_marker: PhantomData<fn() -> (B, T)>,
_marker: PhantomData<fn() -> (S, B, T)>,
}

#[derive(Debug)]
Expand All @@ -63,7 +63,7 @@ struct Paths {
index_file: PathBuf,
}

impl<B> SpaRouter<B, (), fn(io::Error) -> Ready<StatusCode>> {
impl<S, B> SpaRouter<S, B, (), fn(io::Error) -> Ready<StatusCode>> {
/// Create a new `SpaRouter`.
///
/// Assets will be served at `GET /{serve_assets_at}` from the directory at `assets_dir`.
Expand All @@ -86,7 +86,7 @@ impl<B> SpaRouter<B, (), fn(io::Error) -> Ready<StatusCode>> {
}
}

impl<B, T, F> SpaRouter<B, T, F> {
impl<S, B, T, F> SpaRouter<S, B, T, F> {
/// Set the path to the index file.
///
/// `path` must be relative to `assets_dir` passed to [`SpaRouter::new`].
Expand Down Expand Up @@ -138,7 +138,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// let app = Router::new().merge(spa);
/// # let _: Router = app;
/// ```
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<B, T2, F2> {
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<S, B, T2, F2> {
SpaRouter {
paths: self.paths,
handle_error: f,
Expand All @@ -147,16 +147,17 @@ impl<B, T, F> SpaRouter<B, T, F> {
}
}

impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
impl<S, B, F, T> From<SpaRouter<S, B, T, F>> for Router<S, B>
where
F: Clone + Send + Sync + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
fn from(spa: SpaRouter<B, T, F>) -> Self {
fn from(spa: SpaRouter<S, B, T, F>) -> Router<S, B> {
let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir))
.handle_error(spa.handle_error.clone());

Expand Down Expand Up @@ -195,7 +196,7 @@ where
fn clone(&self) -> Self {
Self {
paths: self.paths.clone(),
handle_error: self.handle_error.clone(),
handle_error: self.handle_error,
_marker: self._marker,
}
}
Expand Down Expand Up @@ -264,13 +265,14 @@ mod tests {

let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);

Router::<_, Body>::new().merge(spa);
Router::<(), Body>::new().merge(spa);
}

#[allow(dead_code)]
fn works_with_router_with_state() {
let _: Router<String> = Router::with_state(String::new())
let _: axum::RouterService = Router::new()
.merge(SpaRouter::new("/assets", "test_files"))
.route("/", get(|_: axum::extract::State<String>| async {}));
.route("/", get(|_: axum::extract::State<String>| async {}))
.with_state(String::new());
}
}
2 changes: 1 addition & 1 deletion axum-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ syn = { version = "1.0", features = [
] }

[dev-dependencies]
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] }
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers", "macros"] }
axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing", "cookie-private"] }
rustversion = "1.0"
serde = { version = "1.0", features = ["derive"] }
Expand Down
26 changes: 13 additions & 13 deletions axum-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ use from_request::Trait::{FromRequest, FromRequestParts};
/// rejection type with `#[from_request(rejection(YourType))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{
/// rejection::{ExtensionRejection, StringRejection},
Expand Down Expand Up @@ -463,8 +462,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// As the error message says, handler function needs to be async.
///
/// ```
/// use axum::{routing::get, Router};
/// use axum_macros::debug_handler;
/// use axum::{routing::get, Router, debug_handler};
///
/// #[tokio::main]
/// async fn main() {
Expand Down Expand Up @@ -493,8 +491,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// To work around that the request body type can be customized like so:
///
/// ```
/// use axum::{body::BoxBody, http::Request};
/// # use axum_macros::debug_handler;
/// use axum::{body::BoxBody, http::Request, debug_handler};
///
/// #[debug_handler(body = BoxBody)]
/// async fn handler(request: Request<BoxBody>) {}
Expand All @@ -506,8 +503,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// [`axum::extract::State`] argument:
///
/// ```
/// use axum::extract::State;
/// # use axum_macros::debug_handler;
/// use axum::{debug_handler, extract::State};
///
/// #[debug_handler]
/// async fn handler(
Expand All @@ -523,8 +519,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
///
/// ```
/// use axum::extract::{State, FromRef};
/// # use axum_macros::debug_handler;
/// use axum::{debug_handler, extract::{State, FromRef}};
///
/// #[debug_handler(state = AppState)]
/// async fn handler(
Expand Down Expand Up @@ -579,8 +574,11 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// # Example
///
/// ```
/// use axum_macros::FromRef;
/// use axum::{Router, routing::get, extract::State};
/// use axum::{
/// Router,
/// routing::get,
/// extract::{State, FromRef},
/// };
///
/// #
/// # type AuthToken = String;
Expand All @@ -605,8 +603,10 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// database_pool,
/// };
///
/// let app = Router::with_state(state).route("/", get(handler).post(other_handler));
/// # let _: Router<AppState> = app;
/// let app = Router::new()
/// .route("/", get(handler).post(other_handler))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html
Expand Down
7 changes: 4 additions & 3 deletions axum-macros/tests/from_ref/pass/basic.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use axum_macros::FromRef;
use axum::{Router, routing::get, extract::State};
use axum::{Router, routing::get, extract::{State, FromRef}};

// This will implement `FromRef` for each field in the struct.
#[derive(Clone, FromRef)]
Expand All @@ -15,5 +14,7 @@ fn main() {
auth_token: Default::default(),
};

let _: Router<AppState> = Router::with_state(state).route("/", get(handler));
let _: axum::routing::RouterService = Router::new()
.route("/", get(handler))
.with_state(state);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use axum::{extract::FromRequestParts, response::Response};
use axum_macros::FromRequestParts;

#[derive(FromRequestParts)]
struct Extractor {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
--> tests/from_request/fail/parts_extracting_body.rs:6:11
--> tests/from_request/fail/parts_extracting_body.rs:5:11
|
6 | body: String,
5 | body: String,
| ^^^^^^ the trait `FromRequestParts<S>` is not implemented for `String`
|
= help: the following other types implement trait `FromRequestParts<S>`:
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use axum::{
extract::{FromRequest, Json},
response::Response,
};
use axum_macros::FromRequest;
use serde::Deserialize;

#[derive(Deserialize, FromRequest)]
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/container_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use axum::{
extract::{FromRequestParts, Extension},
response::Response,
};
use axum_macros::FromRequestParts;

#[derive(Clone, FromRequestParts)]
#[from_request(via(Extension))]
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use axum::{
response::Response,
headers::{self, UserAgent},
};
use axum_macros::FromRequest;

#[derive(FromRequest)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use axum::{
headers::{self, UserAgent},
response::Response,
};
use axum_macros::FromRequestParts;

#[derive(FromRequestParts)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named_via.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use axum::{
},
headers::{self, UserAgent},
};
use axum_macros::FromRequest;

#[derive(FromRequest)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named_via_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
},
headers::{self, UserAgent},
};
use axum_macros::FromRequestParts;

#[derive(FromRequestParts)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/override_rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
routing::get,
Extension, Router,
};
use axum_macros::FromRequest;

fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
Expand Down

0 comments on commit 64960bb

Please sign in to comment.