Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type safe state inheritance #1532

Merged
merged 16 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions axum-core/src/extract/default_body_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,30 @@ use tower_layer::Layer;
/// http::Request,
/// };
///
/// let router = Router::new()
/// let app = Router::new()
/// .route(
/// "/",
/// // even with `DefaultBodyLimit` the request body is still just `Body`
/// post(|request: Request<Body>| async {}),
/// )
/// .layer(DefaultBodyLimit::max(1024));
/// # let _: Router<(), _> = app;
/// ```
///
/// ```
/// use axum::{Router, routing::post, body::Body, extract::RawBody, http::Request};
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let router = Router::new()
/// let app = Router::new()
/// .route(
/// "/",
/// // `RequestBodyLimitLayer` changes the request body type to `Limited<Body>`
/// // extracting a different body type wont work
/// post(|request: Request<Limited<Body>>| async {}),
/// )
/// .layer(RequestBodyLimitLayer::new(1024));
/// # let _: Router<(), _> = app;
/// ```
///
/// In general using `DefaultBodyLimit` is recommended but if you need to use third party
Expand Down Expand Up @@ -102,7 +104,7 @@ impl DefaultBodyLimit {
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// let app: Router<(), Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Disable the default limit
/// .layer(DefaultBodyLimit::disable())
Expand Down Expand Up @@ -138,7 +140,7 @@ impl DefaultBodyLimit {
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// let app: Router<(), Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Replace the default of 2MB with 1024 bytes.
/// .layer(DefaultBodyLimit::max(1024));
Expand Down
8 changes: 4 additions & 4 deletions axum-extra/src/extract/cookie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ mod tests {
custom_key: CustomKey(Key::generate()),
};

let app = Router::<_, Body>::with_state(state)
let app = Router::<_, Body>::new()
.route("/set", get(set_cookie))
.route("/get", get(get_cookie))
.route("/remove", get(remove_cookie))
.into_service();
.with_state(state);

let res = app
.clone()
Expand Down Expand Up @@ -352,9 +352,9 @@ mod tests {
custom_key: CustomKey(Key::generate()),
};

let app = Router::<_, Body>::with_state(state)
let app = Router::<_, Body>::new()
.route("/get", get(get_cookie))
.into_service();
.with_state(state);

let res = app
.clone()
Expand Down
7 changes: 4 additions & 3 deletions axum-extra/src/extract/cookie/private.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/set", post(set_secret))
/// .route("/get", get(get_secret));
/// # let app: Router<_> = app;
/// .route("/get", get(get_secret))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub struct PrivateCookieJar<K = Key> {
jar: cookie::CookieJar,
Expand Down
7 changes: 4 additions & 3 deletions axum-extra/src/extract/cookie/signed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/sessions", post(create_session))
/// .route("/me", get(me));
/// # let app: Router<_> = app;
/// .route("/me", get(me))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub struct SignedCookieJar<K = Key> {
jar: cookie::CookieJar,
Expand Down
18 changes: 3 additions & 15 deletions axum-extra/src/routing/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,18 @@ pub struct Resource<S = (), B = Body> {
pub(crate) router: Router<S, B>,
}

impl<B> Resource<(), B>
where
B: axum::body::HttpBody + Send + 'static,
{
/// Create a `Resource` with the given name.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named(resource_name: &str) -> Self {
Self::named_with((), resource_name)
}
}

impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
/// Create a `Resource` with the given name.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named_with(state: S, resource_name: &str) -> Self {
pub fn named(resource_name: &str) -> Self {
Self {
name: resource_name.to_owned(),
router: Router::with_state(state),
router: Router::new(),
}
}

Expand Down
24 changes: 13 additions & 11 deletions axum-extra/src/routing/spa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ use tower_service::Service;
/// - `GET /some/other/path` will serve `index.html` since there isn't another
/// route for it
/// - `GET /api/foo` will serve the `api_foo` handler function
pub struct SpaRouter<B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
pub struct SpaRouter<S = (), B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
paths: Arc<Paths>,
handle_error: F,
_marker: PhantomData<fn() -> (B, T)>,
_marker: PhantomData<fn() -> (S, B, T)>,
}

#[derive(Debug)]
Expand All @@ -63,7 +63,7 @@ struct Paths {
index_file: PathBuf,
}

impl<B> SpaRouter<B, (), fn(io::Error) -> Ready<StatusCode>> {
impl<S, B> SpaRouter<S, B, (), fn(io::Error) -> Ready<StatusCode>> {
/// Create a new `SpaRouter`.
///
/// Assets will be served at `GET /{serve_assets_at}` from the directory at `assets_dir`.
Expand All @@ -86,7 +86,7 @@ impl<B> SpaRouter<B, (), fn(io::Error) -> Ready<StatusCode>> {
}
}

impl<B, T, F> SpaRouter<B, T, F> {
impl<S, B, T, F> SpaRouter<S, B, T, F> {
/// Set the path to the index file.
///
/// `path` must be relative to `assets_dir` passed to [`SpaRouter::new`].
Expand Down Expand Up @@ -138,7 +138,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// let app = Router::new().merge(spa);
/// # let _: Router = app;
/// ```
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<B, T2, F2> {
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<S, B, T2, F2> {
SpaRouter {
paths: self.paths,
handle_error: f,
Expand All @@ -147,16 +147,17 @@ impl<B, T, F> SpaRouter<B, T, F> {
}
}

impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
impl<S, B, F, T> From<SpaRouter<S, B, T, F>> for Router<S, B>
where
F: Clone + Send + Sync + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
fn from(spa: SpaRouter<B, T, F>) -> Self {
fn from(spa: SpaRouter<S, B, T, F>) -> Router<S, B> {
let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir))
.handle_error(spa.handle_error.clone());

Expand Down Expand Up @@ -195,7 +196,7 @@ where
fn clone(&self) -> Self {
Self {
paths: self.paths.clone(),
handle_error: self.handle_error.clone(),
handle_error: self.handle_error,
_marker: self._marker,
}
}
Expand Down Expand Up @@ -264,13 +265,14 @@ mod tests {

let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);

Router::<_, Body>::new().merge(spa);
Router::<(), Body>::new().merge(spa);
}

#[allow(dead_code)]
fn works_with_router_with_state() {
let _: Router<String> = Router::with_state(String::new())
let _: axum::RouterService = Router::new()
.merge(SpaRouter::new("/assets", "test_files"))
.route("/", get(|_: axum::extract::State<String>| async {}));
.route("/", get(|_: axum::extract::State<String>| async {}))
.with_state(String::new());
}
}
2 changes: 1 addition & 1 deletion axum-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ syn = { version = "1.0", features = [
] }

[dev-dependencies]
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] }
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers", "macros"] }
axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing", "cookie-private"] }
rustversion = "1.0"
serde = { version = "1.0", features = ["derive"] }
Expand Down
26 changes: 13 additions & 13 deletions axum-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ use from_request::Trait::{FromRequest, FromRequestParts};
/// rejection type with `#[from_request(rejection(YourType))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{
/// rejection::{ExtensionRejection, StringRejection},
Expand Down Expand Up @@ -463,8 +462,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// As the error message says, handler function needs to be async.
///
/// ```
/// use axum::{routing::get, Router};
/// use axum_macros::debug_handler;
/// use axum::{routing::get, Router, debug_handler};
///
/// #[tokio::main]
/// async fn main() {
Expand Down Expand Up @@ -493,8 +491,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// To work around that the request body type can be customized like so:
///
/// ```
/// use axum::{body::BoxBody, http::Request};
/// # use axum_macros::debug_handler;
/// use axum::{body::BoxBody, http::Request, debug_handler};
///
/// #[debug_handler(body = BoxBody)]
/// async fn handler(request: Request<BoxBody>) {}
Expand All @@ -506,8 +503,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// [`axum::extract::State`] argument:
///
/// ```
/// use axum::extract::State;
/// # use axum_macros::debug_handler;
/// use axum::{debug_handler, extract::State};
///
/// #[debug_handler]
/// async fn handler(
Expand All @@ -523,8 +519,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
///
/// ```
/// use axum::extract::{State, FromRef};
/// # use axum_macros::debug_handler;
/// use axum::{debug_handler, extract::{State, FromRef}};
///
/// #[debug_handler(state = AppState)]
/// async fn handler(
Expand Down Expand Up @@ -579,8 +574,11 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// # Example
///
/// ```
/// use axum_macros::FromRef;
/// use axum::{Router, routing::get, extract::State};
/// use axum::{
/// Router,
/// routing::get,
/// extract::{State, FromRef},
/// };
///
/// #
/// # type AuthToken = String;
Expand All @@ -605,8 +603,10 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// database_pool,
/// };
///
/// let app = Router::with_state(state).route("/", get(handler).post(other_handler));
/// # let _: Router<AppState> = app;
/// let app = Router::new()
/// .route("/", get(handler).post(other_handler))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html
Expand Down
7 changes: 4 additions & 3 deletions axum-macros/tests/from_ref/pass/basic.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use axum_macros::FromRef;
use axum::{Router, routing::get, extract::State};
use axum::{Router, routing::get, extract::{State, FromRef}};

// This will implement `FromRef` for each field in the struct.
#[derive(Clone, FromRef)]
Expand All @@ -15,5 +14,7 @@ fn main() {
auth_token: Default::default(),
};

let _: Router<AppState> = Router::with_state(state).route("/", get(handler));
let _: axum::routing::RouterService = Router::new()
.route("/", get(handler))
.with_state(state);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use axum::{extract::FromRequestParts, response::Response};
use axum_macros::FromRequestParts;

#[derive(FromRequestParts)]
struct Extractor {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
--> tests/from_request/fail/parts_extracting_body.rs:6:11
--> tests/from_request/fail/parts_extracting_body.rs:5:11
|
6 | body: String,
5 | body: String,
| ^^^^^^ the trait `FromRequestParts<S>` is not implemented for `String`
|
= help: the following other types implement trait `FromRequestParts<S>`:
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use axum::{
extract::{FromRequest, Json},
response::Response,
};
use axum_macros::FromRequest;
use serde::Deserialize;

#[derive(Deserialize, FromRequest)]
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/container_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use axum::{
extract::{FromRequestParts, Extension},
response::Response,
};
use axum_macros::FromRequestParts;

#[derive(Clone, FromRequestParts)]
#[from_request(via(Extension))]
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use axum::{
response::Response,
headers::{self, UserAgent},
};
use axum_macros::FromRequest;

#[derive(FromRequest)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use axum::{
headers::{self, UserAgent},
response::Response,
};
use axum_macros::FromRequestParts;

#[derive(FromRequestParts)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named_via.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use axum::{
},
headers::{self, UserAgent},
};
use axum_macros::FromRequest;

#[derive(FromRequest)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/named_via_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
},
headers::{self, UserAgent},
};
use axum_macros::FromRequestParts;

#[derive(FromRequestParts)]
struct Extractor {
Expand Down
1 change: 0 additions & 1 deletion axum-macros/tests/from_request/pass/override_rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use axum::{
routing::get,
Extension, Router,
};
use axum_macros::FromRequest;

fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
Expand Down