Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
robjtede committed Oct 26, 2020
1 parent b31803e commit 320fa71
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 85 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Implement `exclude_regex` for Logger middleware. [#1723]
* Add request-local data extractor `web::ReqData`. [#1748]
* Add `app_data` to `ServiceConfig`. [#1757]
* Expose `on_connect` for access to the connection stream before request is handled. [#1748]

### Changed
* Print non-configured `Data<T>` type when attempting extraction. [#1743]
Expand All @@ -15,6 +16,7 @@
[#1743]: https://github.com/actix/actix-web/pull/1743
[#1748]: https://github.com/actix/actix-web/pull/1748
[#1750]: https://github.com/actix/actix-web/pull/1750
[#1754]: https://github.com/actix/actix-web/pull/1754


## 3.1.0 - 2020-09-29
Expand Down
6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,12 @@ required-features = ["compress"]

[[example]]
name = "on_connect"
required-features = ["rustls"]
required-features = []

[[example]]
name = "client"
required-features = ["rustls"]


[dependencies]
actix-codec = "0.3.0"
actix-service = "1.0.6"
Expand Down Expand Up @@ -114,12 +113,11 @@ tinyvec = { version = "1", features = ["alloc"] }
actix = "0.10.0"
actix-http = { version = "2.0.0", features = ["actors"] }
rand = "0.7"
env_logger = "0.7"
env_logger = "0.8"
serde_derive = "1.0"
brotli2 = "0.3.2"
flate2 = "1.0.13"
criterion = "0.3"
webpki-roots = "0.20"

[profile.release]
lto = true
Expand Down
7 changes: 7 additions & 0 deletions actix-http/CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# Changes

## Unreleased - 2020-xx-xx
### Added
* Added more flexible `on_connect_ext` methods for on-connect handling. [#1754]

### Changed
* Upgrade `base64` to `0.13`.
* Upgrade `pin-project` to `1.0`.

[#1754]: https://github.com/actix/actix-web/pull/1754


## 2.0.0 - 2020-09-11
* No significant changes from `2.0.0-beta.4`.

Expand Down
11 changes: 9 additions & 2 deletions actix-http/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,16 @@ where
} else {
panic!()
};
let (_, cfg, srv, on_connect, on_connect_data, peer_addr) = data.take().unwrap();
let (_, cfg, srv, on_connect, on_connect_data, peer_addr) =
data.take().unwrap();
self.set(State::H2(Dispatcher::new(
srv, conn, on_connect, on_connect_data, cfg, None, peer_addr,
srv,
conn,
on_connect,
on_connect_data,
cfg,
None,
peer_addr,
)));
self.poll(cx)
}
Expand Down
2 changes: 2 additions & 0 deletions actix-http/tests/test_openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,10 @@ async fn test_h2_on_connect() {
let srv = test_server(move || {
HttpService::build()
.on_connect(|_| 10usize)
.on_connect_ext(|_, data| data.insert(20isize))
.h2(|req: Request| {
assert!(req.extensions().contains::<usize>());
assert!(req.extensions().contains::<isize>());
ok::<_, ()>(Response::Ok().finish())
})
.openssl(ssl_acceptor())
Expand Down
2 changes: 2 additions & 0 deletions actix-http/tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,10 @@ async fn test_h1_on_connect() {
let srv = test_server(|| {
HttpService::build()
.on_connect(|_| 10usize)
.on_connect_ext(|_, data| data.insert(20isize))
.h1(|req: Request| {
assert!(req.extensions().contains::<usize>());
assert!(req.extensions().contains::<isize>());
future::ok::<_, ()>(Response::Ok().finish())
})
.tcp()
Expand Down
101 changes: 24 additions & 77 deletions examples/on_connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,100 +4,47 @@
//! For an example of extracting a client TLS certificate, see:
//! <https://github.com/actix/examples/tree/HEAD/rustls-client-cert>

use std::{any::Any, env, fs::File, io::BufReader};
use std::{any::Any, env, io, net::SocketAddr};

use actix_tls::rustls::{ServerConfig, TlsStream};
use actix_web::{
dev::Extensions, rt::net::TcpStream, web, App, HttpResponse, HttpServer, Responder,
};
use log::info;
use rust_tls::{
internal::pemfile::{certs, pkcs8_private_keys},
AllowAnyAnonymousOrAuthenticatedClient, Certificate, RootCertStore, Session,
};

const CA_CERT: &str = "examples/certs/rootCA.pem";
const SERVER_CERT: &str = "examples/certs/server-cert.pem";
const SERVER_KEY: &str = "examples/certs/server-key.pem";
use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer};

#[derive(Debug, Clone)]
struct ConnectionInfo(String);

async fn route_whoami(
conn_info: web::ReqData<ConnectionInfo>,
client_cert: Option<web::ReqData<Certificate>>,
) -> impl Responder {
if let Some(cert) = client_cert {
HttpResponse::Ok().body(format!("{:?}\n\n{:?}", &conn_info, &cert))
} else {
HttpResponse::Unauthorized().body("No client certificate provided.")
}
struct ConnectionInfo {
bind: SocketAddr,
peer: SocketAddr,
ttl: Option<u32>,
}

fn get_client_cert(connection: &dyn Any, data: &mut Extensions) {
if let Some(tls_socket) = connection.downcast_ref::<TlsStream<TcpStream>>() {
info!("TLS on_connect");

let (socket, tls_session) = tls_socket.get_ref();

let msg = format!(
"local_addr={:?}; peer_addr={:?}",
socket.local_addr(),
socket.peer_addr()
);

data.insert(ConnectionInfo(msg));

if let Some(mut certs) = tls_session.get_peer_certificates() {
info!("client certificate found");

// insert a `rustls::Certificate` into request data
data.insert(certs.pop().unwrap());
}
} else if let Some(socket) = connection.downcast_ref::<TcpStream>() {
info!("plaintext on_connect");

let msg = format!(
"local_addr={:?}; peer_addr={:?}",
socket.local_addr(),
socket.peer_addr()
);
async fn route_whoami(conn_info: web::ReqData<ConnectionInfo>) -> String {
format!(
"Here is some info about your connection:\n\n{:#?}",
conn_info
)
}

data.insert(ConnectionInfo(msg));
fn get_conn_info(connection: &dyn Any, data: &mut Extensions) {
if let Some(sock) = connection.downcast_ref::<TcpStream>() {
data.insert(ConnectionInfo {
bind: sock.local_addr().unwrap(),
peer: sock.peer_addr().unwrap(),
ttl: sock.ttl().ok(),
});
} else {
unreachable!("socket should be TLS or plaintext");
unreachable!("connection should only be plaintext since no TLS is set up");
}
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
async fn main() -> io::Result<()> {
if env::var("RUST_LOG").is_err() {
env::set_var("RUST_LOG", "info");
}

env_logger::init();

let ca_cert = &mut BufReader::new(File::open(CA_CERT)?);

let mut cert_store = RootCertStore::empty();
cert_store
.add_pem_file(ca_cert)
.expect("root CA not added to store");
let client_auth = AllowAnyAnonymousOrAuthenticatedClient::new(cert_store);

let mut config = ServerConfig::new(client_auth);

let cert_file = &mut BufReader::new(File::open(SERVER_CERT)?);
let key_file = &mut BufReader::new(File::open(SERVER_KEY)?);

let cert_chain = certs(cert_file).unwrap();
let mut keys = pkcs8_private_keys(key_file).unwrap();
config.set_single_cert(cert_chain, keys.remove(0)).unwrap();

HttpServer::new(|| App::new().route("/", web::get().to(route_whoami)))
.on_connect(get_client_cert)
.bind(("localhost", 8080))?
.bind_rustls(("localhost", 8443), config)?
HttpServer::new(|| App::new().default_service(web::to(route_whoami)))
.on_connect(get_conn_info)
.bind(("127.0.0.1", 8080))?
.workers(1)
.run()
.await
Expand Down
5 changes: 3 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ where
}
}

/// Sets function that will be called once for each connection.
/// It will receive &Any, which contains underlying connection type.
/// Sets function that will be called once before each connection is handled.
/// It will receive a `&std::any::Any`, which contains underlying connection type and an
/// [Extensions] container so that request-local data can be passed to middleware and handlers.
///
/// For example:
/// - `actix_tls::openssl::SslStream<actix_web::rt::net::TcpStream>` when using openssl.
Expand Down
45 changes: 45 additions & 0 deletions tests/test_on_connect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::{any::Any, env, io, net::SocketAddr};

use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer};

#[derive(Debug, Clone)]
struct ConnectionInfo {
bind: SocketAddr,
peer: SocketAddr,
ttl: Option<u32>,
}

async fn route_whoami(conn_info: web::ReqData<ConnectionInfo>) -> String {
format!(
"Here is some info about your connection:\n\n{:#?}",
conn_info
)
}

fn get_conn_info(connection: &dyn Any, data: &mut Extensions) {
if let Some(sock) = connection.downcast_ref::<TcpStream>() {
data.insert(ConnectionInfo {
bind: sock.local_addr().unwrap(),
peer: sock.peer_addr().unwrap(),
ttl: sock.ttl().ok(),
});
} else {
unreachable!("connection should only be plaintext since no TLS is set up");
}
}

#[actix_web::main]
async fn main() -> io::Result<()> {
if env::var("RUST_LOG").is_err() {
env::set_var("RUST_LOG", "info");
}

env_logger::init();

HttpServer::new(|| App::new().default_service(web::to(route_whoami)))
.on_connect(get_conn_info)
.bind(("127.0.0.1", 8080))?
.workers(1)
.run()
.await
}

0 comments on commit 320fa71

Please sign in to comment.