Skip to content

Commit

Permalink
feat(client): Add connection capturing API to hyper-util (#112)
Browse files Browse the repository at this point in the history
- rework features to allow enabling only tokio/sync for the client
- a `capture_connection` API
  • Loading branch information
rcoh committed Apr 15, 2024
1 parent f87fe0d commit a77d866
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 5 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pin-project-lite = "0.2.4"
futures-channel = { version = "0.3", optional = true }
socket2 = { version = "0.5", optional = true, features = ["all"] }
tracing = { version = "0.1", default-features = false, features = ["std"], optional = true }
tokio = { version = "1", optional = true, features = ["net", "rt", "time"] }
tokio = { version = "1", optional = true, default-features = false }
tower-service ={ version = "0.3", optional = true }
tower = { version = "0.4.1", optional = true, default-features = false, features = ["make", "util"] }

Expand Down Expand Up @@ -57,7 +57,7 @@ full = [
]

client = ["hyper/client", "dep:tracing", "dep:futures-channel", "dep:tower", "dep:tower-service"]
client-legacy = ["client", "dep:socket2"]
client-legacy = ["client", "dep:socket2", "tokio/sync"]

server = ["hyper/server"]
server-auto = ["server", "http1", "http2"]
Expand All @@ -67,7 +67,7 @@ service = ["dep:tower", "dep:tower-service"]
http1 = ["hyper/http1"]
http2 = ["hyper/http2"]

tokio = ["dep:tokio"]
tokio = ["dep:tokio", "tokio/net", "tokio/rt", "tokio/time"]

# internal features used in CI
__internal_happy_eyeballs_tests = []
Expand Down
5 changes: 5 additions & 0 deletions src/client/legacy/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use hyper::rt::Timer;
use hyper::{body::Body, Method, Request, Response, Uri, Version};
use tracing::{debug, trace, warn};

use super::connect::capture::CaptureConnectionExtension;
#[cfg(feature = "tokio")]
use super::connect::HttpConnector;
use super::connect::{Alpn, Connect, Connected, Connection};
Expand Down Expand Up @@ -265,6 +266,10 @@ where
) -> Result<Response<hyper::body::Incoming>, Error> {
let mut pooled = self.connection_for(pool_key).await?;

req.extensions_mut()
.get_mut::<CaptureConnectionExtension>()
.map(|conn| conn.set(&pooled.conn_info));

if pooled.is_http1() {
if req.version() == Version::HTTP_2 {
warn!("Connection is HTTP/1, but request requires HTTP/2");
Expand Down
191 changes: 191 additions & 0 deletions src/client/legacy/connect/capture.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use std::{ops::Deref, sync::Arc};

use http::Request;
use tokio::sync::watch;

use super::Connected;

/// [`CaptureConnection`] allows callers to capture [`Connected`] information
///
/// To capture a connection for a request, use [`capture_connection`].
#[derive(Debug, Clone)]
pub struct CaptureConnection {
rx: watch::Receiver<Option<Connected>>,
}

/// Capture the connection for a given request
///
/// When making a request with Hyper, the underlying connection must implement the [`Connection`] trait.
/// [`capture_connection`] allows a caller to capture the returned [`Connected`] structure as soon
/// as the connection is established.
///
/// *Note*: If establishing a connection fails, [`CaptureConnection::connection_metadata`] will always return none.
///
/// # Examples
///
/// **Synchronous access**:
/// The [`CaptureConnection::connection_metadata`] method allows callers to check if a connection has been
/// established. This is ideal for situations where you are certain the connection has already
/// been established (e.g. after the response future has already completed).
/// ```rust
/// use hyper_util::client::legacy::connect::capture_connection;
/// let mut request = http::Request::builder()
/// .uri("http://foo.com")
/// .body(())
/// .unwrap();
///
/// let captured_connection = capture_connection(&mut request);
/// // some time later after the request has been sent...
/// let connection_info = captured_connection.connection_metadata();
/// println!("we are connected! {:?}", connection_info.as_ref());
/// ```
///
/// **Asynchronous access**:
/// The [`CaptureConnection::wait_for_connection_metadata`] method returns a future resolves as soon as the
/// connection is available.
///
/// ```rust
/// # #[cfg(feature = "tokio")]
/// # async fn example() {
/// use hyper_util::client::legacy::connect::capture_connection;
/// use hyper_util::client::legacy::Client;
/// use hyper_util::rt::TokioExecutor;
/// use bytes::Bytes;
/// use http_body_util::Empty;
/// let mut request = http::Request::builder()
/// .uri("http://foo.com")
/// .body(Empty::<Bytes>::new())
/// .unwrap();
///
/// let mut captured = capture_connection(&mut request);
/// tokio::task::spawn(async move {
/// let connection_info = captured.wait_for_connection_metadata().await;
/// println!("we are connected! {:?}", connection_info.as_ref());
/// });
///
/// let client = Client::builder(TokioExecutor::new()).build_http();
/// client.request(request).await.expect("request failed");
/// # }
/// ```
pub fn capture_connection<B>(request: &mut Request<B>) -> CaptureConnection {
let (tx, rx) = CaptureConnection::new();
request.extensions_mut().insert(tx);
rx
}

/// TxSide for [`CaptureConnection`]
///
/// This is inserted into `Extensions` to allow Hyper to back channel connection info
#[derive(Clone)]
pub(crate) struct CaptureConnectionExtension {
tx: Arc<watch::Sender<Option<Connected>>>,
}

impl CaptureConnectionExtension {
pub(crate) fn set(&self, connected: &Connected) {
self.tx.send_replace(Some(connected.clone()));
}
}

impl CaptureConnection {
/// Internal API to create the tx and rx half of [`CaptureConnection`]
pub(crate) fn new() -> (CaptureConnectionExtension, Self) {
let (tx, rx) = watch::channel(None);
(
CaptureConnectionExtension { tx: Arc::new(tx) },
CaptureConnection { rx },
)
}

/// Retrieve the connection metadata, if available
pub fn connection_metadata(&self) -> impl Deref<Target = Option<Connected>> + '_ {
self.rx.borrow()
}

/// Wait for the connection to be established
///
/// If a connection was established, this will always return `Some(...)`. If the request never
/// successfully connected (e.g. DNS resolution failure), this method will never return.
pub async fn wait_for_connection_metadata(
&mut self,
) -> impl Deref<Target = Option<Connected>> + '_ {
if self.rx.borrow().is_some() {
return self.rx.borrow();
}
let _ = self.rx.changed().await;
self.rx.borrow()
}
}

#[cfg(all(test, not(miri)))]
mod test {
use super::*;

#[test]
fn test_sync_capture_connection() {
let (tx, rx) = CaptureConnection::new();
assert!(
rx.connection_metadata().is_none(),
"connection has not been set"
);
tx.set(&Connected::new().proxy(true));
assert_eq!(
rx.connection_metadata()
.as_ref()
.expect("connected should be set")
.is_proxied(),
true
);

// ensure it can be called multiple times
assert_eq!(
rx.connection_metadata()
.as_ref()
.expect("connected should be set")
.is_proxied(),
true
);
}

#[tokio::test]
async fn async_capture_connection() {
let (tx, mut rx) = CaptureConnection::new();
assert!(
rx.connection_metadata().is_none(),
"connection has not been set"
);
let test_task = tokio::spawn(async move {
assert_eq!(
rx.wait_for_connection_metadata()
.await
.as_ref()
.expect("connection should be set")
.is_proxied(),
true
);
// can be awaited multiple times
assert!(
rx.wait_for_connection_metadata().await.is_some(),
"should be awaitable multiple times"
);

assert_eq!(rx.connection_metadata().is_some(), true);
});
// can't be finished, we haven't set the connection yet
assert_eq!(test_task.is_finished(), false);
tx.set(&Connected::new().proxy(true));

assert!(test_task.await.is_ok());
}

#[tokio::test]
async fn capture_connection_sender_side_dropped() {
let (tx, mut rx) = CaptureConnection::new();
assert!(
rx.connection_metadata().is_none(),
"connection has not been set"
);
drop(tx);
assert!(rx.wait_for_connection_metadata().await.is_none());
}
}
4 changes: 3 additions & 1 deletion src/client/legacy/connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ pub mod dns;
#[cfg(feature = "tokio")]
mod http;

pub(crate) mod capture;
pub use capture::{capture_connection, CaptureConnection};

pub use self::sealed::Connect;

/// Describes a type returned by a connector.
Expand Down Expand Up @@ -169,7 +172,6 @@ impl Connected {

// Don't public expose that `Connected` is `Clone`, unsure if we want to
// keep that contract...
#[cfg(feature = "http2")]
pub(super) fn clone(&self) -> Connected {
Connected {
alpn: self.alpn,
Expand Down
34 changes: 33 additions & 1 deletion tests/legacy_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use http_body_util::{Empty, Full, StreamBody};
use hyper::body::Bytes;
use hyper::body::Frame;
use hyper::Request;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::connect::{capture_connection, HttpConnector};
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};

Expand Down Expand Up @@ -876,3 +876,35 @@ fn alpn_h2() {
);
drop(client);
}

#[cfg(not(miri))]
#[test]
fn capture_connection_on_client() {
let _ = pretty_env_logger::try_init();

let rt = runtime();
let connector = DebugConnector::new();

let client = Client::builder(TokioExecutor::new()).build(connector);

let server = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = server.local_addr().unwrap();
thread::spawn(move || {
let mut sock = server.accept().unwrap().0;
//drop(server);
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
sock.set_write_timeout(Some(Duration::from_secs(5)))
.unwrap();
let mut buf = [0; 4096];
sock.read(&mut buf).expect("read 1");
sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.expect("write 1");
});
let mut req = Request::builder()
.uri(&*format!("http://{}/a", addr))
.body(Empty::<Bytes>::new())
.unwrap();
let captured_conn = capture_connection(&mut req);
rt.block_on(client.request(req)).expect("200 OK");
assert!(captured_conn.connection_metadata().is_some());
}

0 comments on commit a77d866

Please sign in to comment.