Skip to content

Commit

Permalink
expose on_connect v2 (#1754)
Browse files Browse the repository at this point in the history
Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
  • Loading branch information
robjtede and MikailBag committed Oct 30, 2020
1 parent 4519db3 commit 9963a5e
Show file tree
Hide file tree
Showing 16 changed files with 373 additions and 71 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* Add request-local data extractor `web::ReqData`. [#1748]
* Add ability to register closure for request middleware logging. [#1749]
* Add `app_data` to `ServiceConfig`. [#1757]
* Expose `on_connect` for access to the connection stream before request is handled. [#1754]

### Changed
* Print non-configured `Data<T>` type when attempting extraction. [#1743]
Expand All @@ -16,6 +17,7 @@
[#1743]: https://github.com/actix/actix-web/pull/1743
[#1748]: https://github.com/actix/actix-web/pull/1748
[#1750]: https://github.com/actix/actix-web/pull/1750
[#1754]: https://github.com/actix/actix-web/pull/1754
[#1749]: https://github.com/actix/actix-web/pull/1749


Expand Down
14 changes: 9 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ required-features = ["compress"]
name = "test_server"
required-features = ["compress"]

[[example]]
name = "on_connect"
required-features = []

[[example]]
name = "client"
required-features = ["rustls"]

[dependencies]
actix-codec = "0.3.0"
actix-service = "1.0.6"
Expand Down Expand Up @@ -105,7 +113,7 @@ tinyvec = { version = "1", features = ["alloc"] }
actix = "0.10.0"
actix-http = { version = "2.0.0", features = ["actors"] }
rand = "0.7"
env_logger = "0.7"
env_logger = "0.8"
serde_derive = "1.0"
brotli2 = "0.3.2"
flate2 = "1.0.13"
Expand All @@ -125,10 +133,6 @@ actix-files = { path = "actix-files" }
actix-multipart = { path = "actix-multipart" }
awc = { path = "awc" }

[[example]]
name = "client"
required-features = ["rustls"]

[[bench]]
name = "server"
harness = false
Expand Down
7 changes: 7 additions & 0 deletions actix-http/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# Changes

## Unreleased - 2020-xx-xx
### Added
* Added more flexible `on_connect_ext` methods for on-connect handling. [#1754]

### Changed
* Upgrade `base64` to `0.13`.
* Upgrade `pin-project` to `1.0`.

[#1754]: https://github.com/actix/actix-web/pull/1754


## 2.0.0 - 2020-09-11
* No significant changes from `2.0.0-beta.4`.

Expand Down
41 changes: 34 additions & 7 deletions actix-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ use crate::helpers::{Data, DataFactory};
use crate::request::Request;
use crate::response::Response;
use crate::service::HttpService;
use crate::{ConnectCallback, Extensions};

/// A http service builder
/// A HTTP service builder
///
/// This type can be used to construct an instance of `http service` through a
/// This type can be used to construct an instance of [`HttpService`] through a
/// builder-like pattern.
pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
keep_alive: KeepAlive,
Expand All @@ -27,7 +28,9 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
local_addr: Option<net::SocketAddr>,
expect: X,
upgrade: Option<U>,
// DEPRECATED: in favor of on_connect_ext
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, S)>,
}

Expand All @@ -49,6 +52,7 @@ where
expect: ExpectHandler,
upgrade: None,
on_connect: None,
on_connect_ext: None,
_t: PhantomData,
}
}
Expand Down Expand Up @@ -138,6 +142,7 @@ where
expect: expect.into_factory(),
upgrade: self.upgrade,
on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData,
}
}
Expand Down Expand Up @@ -167,14 +172,16 @@ where
expect: self.expect,
upgrade: Some(upgrade.into_factory()),
on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData,
}
}

/// Set on-connect callback.
///
/// It get called once per connection and result of the call
/// get stored to the request's extensions.
/// Called once per connection. Return value of the call is stored in request extensions.
///
/// *SOFT DEPRECATED*: Prefer the `on_connect_ext` style callback.
pub fn on_connect<F, I>(mut self, f: F) -> Self
where
F: Fn(&T) -> I + 'static,
Expand All @@ -184,7 +191,20 @@ where
self
}

/// Finish service configuration and create *http service* for HTTP/1 protocol.
/// Sets the callback to be run on connection establishment.
///
/// Has mutable access to a data container that will be merged into request extensions.
/// This enables transport layer data (like client certificates) to be accessed in middleware
/// and handlers.
pub fn on_connect_ext<F>(mut self, f: F) -> Self
where
F: Fn(&T, &mut Extensions) + 'static,
{
self.on_connect_ext = Some(Rc::new(f));
self
}

/// Finish service configuration and create a HTTP Service for HTTP/1 protocol.
pub fn h1<F, B>(self, service: F) -> H1Service<T, S, B, X, U>
where
B: MessageBody,
Expand All @@ -200,13 +220,15 @@ where
self.secure,
self.local_addr,
);

H1Service::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_connect(self.on_connect)
.on_connect_ext(self.on_connect_ext)
}

/// Finish service configuration and create *http service* for HTTP/2 protocol.
/// Finish service configuration and create a HTTP service for HTTP/2 protocol.
pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
where
B: MessageBody + 'static,
Expand All @@ -223,7 +245,10 @@ where
self.secure,
self.local_addr,
);
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)

H2Service::with_config(cfg, service.into_factory())
.on_connect(self.on_connect)
.on_connect_ext(self.on_connect_ext)
}

/// Finish service configuration and create `HttpService` instance.
Expand All @@ -243,9 +268,11 @@ where
self.secure,
self.local_addr,
);

HttpService::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_connect(self.on_connect)
.on_connect_ext(self.on_connect_ext)
}
}
30 changes: 29 additions & 1 deletion actix-http/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::any::{Any, TypeId};
use std::fmt;
use std::{fmt, mem};

use fxhash::FxHashMap;

Expand Down Expand Up @@ -66,6 +66,11 @@ impl Extensions {
pub fn extend(&mut self, other: Extensions) {
self.map.extend(other.map);
}

/// Sets (or overrides) items from `other` into this map.
pub(crate) fn drain_from(&mut self, other: &mut Self) {
self.map.extend(mem::take(&mut other.map));
}
}

impl fmt::Debug for Extensions {
Expand Down Expand Up @@ -213,4 +218,27 @@ mod tests {
assert_eq!(extensions.get(), Some(&20u8));
assert_eq!(extensions.get_mut(), Some(&mut 20u8));
}

#[test]
fn test_drain_from() {
let mut ext = Extensions::new();
ext.insert(2isize);

let mut more_ext = Extensions::new();

more_ext.insert(5isize);
more_ext.insert(5usize);

assert_eq!(ext.get::<isize>(), Some(&2isize));
assert_eq!(ext.get::<usize>(), None);
assert_eq!(more_ext.get::<isize>(), Some(&5isize));
assert_eq!(more_ext.get::<usize>(), Some(&5usize));

ext.drain_from(&mut more_ext);

assert_eq!(ext.get::<isize>(), Some(&5isize));
assert_eq!(ext.get::<usize>(), Some(&5usize));
assert_eq!(more_ext.get::<isize>(), None);
assert_eq!(more_ext.get::<usize>(), None);
}
}
18 changes: 16 additions & 2 deletions actix-http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use bytes::{Buf, BytesMut};
use log::{error, trace};
use pin_project::pin_project;

use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService;
use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error};
Expand All @@ -21,6 +20,10 @@ use crate::helpers::DataFactory;
use crate::httpmessage::HttpMessage;
use crate::request::Request;
use crate::response::Response;
use crate::{
body::{Body, BodySize, MessageBody, ResponseBody},
Extensions,
};

use super::codec::Codec;
use super::payload::{Payload, PayloadSender, PayloadStatus};
Expand Down Expand Up @@ -88,6 +91,7 @@ where
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
flags: Flags,
peer_addr: Option<net::SocketAddr>,
error: Option<DispatchError>,
Expand Down Expand Up @@ -167,14 +171,15 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
/// Create http/1 dispatcher.
/// Create HTTP/1 dispatcher.
pub(crate) fn new(
stream: T,
config: ServiceConfig,
service: CloneableService<S>,
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>,
) -> Self {
Dispatcher::with_timeout(
Expand All @@ -187,6 +192,7 @@ where
expect,
upgrade,
on_connect,
on_connect_data,
peer_addr,
)
}
Expand All @@ -202,6 +208,7 @@ where
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>,
) -> Self {
let keepalive = config.keep_alive_enabled();
Expand Down Expand Up @@ -234,6 +241,7 @@ where
expect,
upgrade,
on_connect,
on_connect_data,
flags,
peer_addr,
ka_expire,
Expand Down Expand Up @@ -526,11 +534,15 @@ where
let pl = this.codec.message_type();
req.head_mut().peer_addr = *this.peer_addr;

// DEPRECATED
// set on_connect data
if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut());
}

// merge on_connect_ext data into request extensions
req.extensions_mut().drain_from(this.on_connect_data);

if pl == MessageType::Stream && this.upgrade.is_some() {
this.messages.push_back(DispatcherMessage::Upgrade(req));
break;
Expand Down Expand Up @@ -927,8 +939,10 @@ mod tests {
CloneableService::new(ExpectHandler),
None,
None,
Extensions::new(),
None,
);

match Pin::new(&mut h1).poll(cx) {
Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()),
Expand Down

0 comments on commit 9963a5e

Please sign in to comment.