Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

httpauth: Refactor out Mutex #69

Merged
merged 2 commits into from Jun 11, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
90 changes: 81 additions & 9 deletions actix-web-httpauth/src/middleware.rs
@@ -1,15 +1,16 @@
//! HTTP Authentication middleware.

use std::cell::RefCell;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;

use actix_service::{Service, Transform};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::Error;
use futures_util::future::{self, FutureExt, LocalBoxFuture, TryFutureExt};
use futures_util::lock::Mutex;
use futures_util::task::{Context, Poll};

use crate::extractors::{basic, bearer, AuthExtractor};
Expand Down Expand Up @@ -142,7 +143,7 @@ where

fn new_transform(&self, service: S) -> Self::Future {
future::ok(AuthenticationMiddleware {
service: Arc::new(Mutex::new(service)),
service: Rc::new(RefCell::new(service)),
process_fn: self.process_fn.clone(),
_extractor: PhantomData,
})
Expand All @@ -154,7 +155,7 @@ pub struct AuthenticationMiddleware<S, F, T>
where
T: AuthExtractor,
{
service: Arc<Mutex<S>>,
service: Rc<RefCell<S>>,
process_fn: Arc<F>,
_extractor: PhantomData<T>,
}
Expand All @@ -181,21 +182,22 @@ where
ctx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.service
.try_lock()
.expect("AuthenticationMiddleware was called already")
.borrow_mut()
.poll_ready(ctx)
}

fn call(&mut self, req: Self::Request) -> Self::Future {
let process_fn = self.process_fn.clone();
// Note: cloning the mutex, not the service itself
let inner = self.service.clone();

let service = Rc::clone(&self.service);

async move {
let (req, credentials) = Extract::<T>::new(req).await?;
let req = process_fn(req, credentials).await?;
let mut service = inner.lock().await;
service.call(req).await
// It is important that `borrow_mut()` and `.await` are on
// separate lines, or else a panic occurs.
let fut = service.borrow_mut().call(req);
fut.await
}
.boxed_local()
}
Expand Down Expand Up @@ -246,3 +248,73 @@ where
Poll::Ready(Ok((req, credentials)))
}
}

#[cfg(test)]
mod tests {
use super::*;
use actix_web::test::TestRequest;
use actix_service::{into_service, Service};
use futures_util::join;
use crate::extractors::bearer::BearerAuth;
use actix_web::error;


/// This is a test for https://github.com/actix/actix-extras/issues/10
#[actix_rt::test]
async fn test_middleware_panic() {
let mut middleware = AuthenticationMiddleware {
service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| {
async move {
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}}))),
process_fn: Arc::new(|req, _: BearerAuth| async {
Ok(req) }),
_extractor: PhantomData,
};

let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();

let f = middleware.call(req);

let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx) );


assert!(join!(f, res).0.is_err());
}

/// This is a test for https://github.com/actix/actix-extras/issues/10
#[actix_rt::test]
async fn test_middleware_panic_several_orders() {
let mut middleware = AuthenticationMiddleware {
service: Rc::new(RefCell::new(into_service(|_: ServiceRequest| {
async move {
actix_rt::time::delay_for(std::time::Duration::from_secs(1)).await;
Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
}}))),
process_fn: Arc::new(|req, _: BearerAuth| async {
Ok(req) }),
_extractor: PhantomData,
};

let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();

let f1 = middleware.call(req);

let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();

let f2 = middleware.call(req);

let req = TestRequest::with_header("Authorization", "Bearer 1").to_srv_request();

let f3 = middleware.call(req);

let res = futures_util::future::lazy(|cx| middleware.poll_ready(cx));

let result = join!(f1, f2, f3, res);

assert!(result.0.is_err());
assert!(result.1.is_err());
assert!(result.2.is_err());
}
}