Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ext::with_mut #457

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 87 additions & 1 deletion src/filters/ext.rs
Expand Up @@ -4,9 +4,11 @@ use std::convert::Infallible;

use futures::future;

use crate::filter::{filter_fn_one, Filter};
use crate::filter::{filter_fn_one, Filter, WrapSealed};
use crate::reject::{self, Rejection};

use self::internal::WithExtensions_;

/// Get a previously set extension of the current route.
///
/// If the extension doesn't exist, this rejects with a `MissingExtension`.
Expand Down Expand Up @@ -34,3 +36,87 @@ unit_error! {
/// An error used to reject if `get` cannot find the extension.
pub MissingExtension: "Missing request extension"
}

/// Access to request `Extensions`.
/// A given function is called before reaching a wrapped filter.
pub fn with_mut<F>(f: F) -> WithExtensions<F>
where
F: Fn(&mut http::Extensions),
{
WithExtensions { f }
}

/// Decorates a `Filter` to access `http::Extensions`.
#[derive(Clone, Copy, Debug)]
pub struct WithExtensions<F> {
f: F,
}

impl<FN, F> WrapSealed<F> for WithExtensions<FN>
where
FN: Fn(&mut http::Extensions) + Clone + Send,
F: Filter + Clone + Send,
{
type Wrapped = WithExtensions_<FN, F>;

fn wrap(&self, filter: F) -> Self::Wrapped {
WithExtensions_ {
f: self.f.clone(),
filter,
}
}
}

mod internal {
#[allow(missing_debug_implementations)]
pub struct WithExtensions_<FN, F> {
pub(super) f: FN,
pub(super) filter: F,
}

impl<FN, F> FilterBase for WithExtensions_<FN, F>
where
FN: Fn(&mut http::Extensions) + Clone + Send,
F: Filter,
{
type Extract = F::Extract;
type Error = F::Error;
type Future = WithExtensionsFuture<FN, F::Future>;

fn filter(&self, _: Internal) -> Self::Future {
WithExtensionsFuture {
f: self.f.clone(),
future: self.filter.filter(Internal),
}
}
}

use crate::filter::{Filter, FilterBase, Internal};
use crate::route;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

#[allow(missing_debug_implementations)]
#[pin_project]
pub struct WithExtensionsFuture<FN, F> {
f: FN,
#[pin]
future: F,
}

impl<FN, F> Future for WithExtensionsFuture<FN, F>
where
F: Future,
FN: Fn(&mut http::Extensions),
{
type Output = F::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let pin = self.as_mut().project();
route::with(|route| (pin.f)(route.extensions_mut()));
pin.future.poll(cx)
}
}
}
2 changes: 0 additions & 2 deletions src/route.rs
Expand Up @@ -75,11 +75,9 @@ impl Route {
self.req.extensions()
}

/*
pub(crate) fn extensions_mut(&mut self) -> &mut http::Extensions {
self.req.extensions_mut()
}
*/

pub(crate) fn uri(&self) -> &http::Uri {
self.req.uri()
Expand Down
12 changes: 6 additions & 6 deletions tests/ext.rs
Expand Up @@ -6,13 +6,13 @@ struct Ext1(i32);

#[tokio::test]
async fn set_and_get() {
let ext = warp::ext::get::<Ext1>();
let ext_filter = warp::any()
.with(warp::ext::with_mut(|ext| {
ext.insert(Ext1(55));
}))
.and(warp::ext::get::<Ext1>());

let extracted = warp::test::request()
.extension(Ext1(55))
.filter(&ext)
.await
.unwrap();
let extracted = warp::test::request().filter(&ext_filter).await.unwrap();

assert_eq!(extracted, Ext1(55));
}
Expand Down