Skip to content

Commit

Permalink
Allow overriding server name
Browse files Browse the repository at this point in the history
Signed-off-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
  • Loading branch information
MikailBag committed Sep 17, 2022
1 parent cfaff38 commit 76ced6d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/connector.rs
Expand Up @@ -21,6 +21,7 @@ pub struct HttpsConnector<T> {
force_https: bool,
http: T,
tls_config: Arc<rustls::ClientConfig>,
server_name_override: Option<String>,
}

impl<T> fmt::Debug for HttpsConnector<T> {
Expand All @@ -40,6 +41,7 @@ where
force_https: false,
http,
tls_config: cfg.into(),
server_name_override: None,
}
}
}
Expand Down Expand Up @@ -83,21 +85,30 @@ where
Box::pin(f)
} else if sch == &http::uri::Scheme::HTTPS {
let cfg = self.tls_config.clone();
let hostname = dst
.host()
.unwrap_or_default()
.to_string();
let hostname = match self.server_name_override.as_deref() {
Some(h) => h,
None => dst
.host()
.unwrap_or_default()
};
let hostname = match rustls::ServerName::try_from(hostname) {
Ok(dnsname) => dnsname,
Err(_) => {
let err = io::Error::new(io::ErrorKind::Other, "invalid dnsname");
let err: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(err);
let res = std::future::ready(Err(err));
return Box::pin(res);
}
};
let connecting_future = self.http.call(dst);

let f = async move {
let tcp = connecting_future
.await
.map_err(Into::into)?;
let connector = TlsConnector::from(cfg);
let dnsname = rustls::ServerName::try_from(hostname.as_str())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid dnsname"))?;
let tls = connector
.connect(dnsname, tcp)
.connect(hostname, tcp)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(MaybeHttpsStream::Https(tls))
Expand Down
31 changes: 31 additions & 0 deletions src/connector/builder.rs
Expand Up @@ -106,6 +106,7 @@ impl ConnectorBuilder<WantsSchemes> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: true,
server_name_override: None,
})
}

Expand All @@ -117,6 +118,7 @@ impl ConnectorBuilder<WantsSchemes> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: false,
server_name_override: None
})
}
}
Expand All @@ -128,6 +130,7 @@ impl ConnectorBuilder<WantsSchemes> {
pub struct WantsProtocols1 {
tls_config: ClientConfig,
https_only: bool,
server_name_override: Option<String>
}

impl WantsProtocols1 {
Expand All @@ -136,6 +139,7 @@ impl WantsProtocols1 {
force_https: self.https_only,
http: conn,
tls_config: std::sync::Arc::new(self.tls_config),
server_name_override: None
}
}

Expand Down Expand Up @@ -169,6 +173,15 @@ impl ConnectorBuilder<WantsProtocols1> {
enable_http1: false,
})
}

/// Override expected server name
///
/// If called, server certificates will be validated against `server_name_override`,
/// and host portion of URL will not be used for server authentication.
pub fn override_server_name(mut self, server_name_override: String) -> Self {
self.0.server_name_override = Some(server_name_override);
self
}
}

/// State of a builder with HTTP1 enabled, that may have some other
Expand All @@ -195,6 +208,15 @@ impl ConnectorBuilder<WantsProtocols2> {
})
}

/// Override expected server name
///
/// If called, server certificates will be validated against `server_name_override`,
/// and host portion of URL will not be used for server authentication.
pub fn override_server_name(mut self, server_name_override: String) -> Self {
self.0.inner.server_name_override = Some(server_name_override);
self
}

/// This builds an [`HttpsConnector`] built on hyper's default [`HttpConnector`]
#[cfg(feature = "tokio-runtime")]
pub fn build(self) -> HttpsConnector<HttpConnector> {
Expand Down Expand Up @@ -226,6 +248,15 @@ pub struct WantsProtocols3 {

#[cfg(feature = "http2")]
impl ConnectorBuilder<WantsProtocols3> {
/// Override expected server name
///
/// If called, server certificates will be validated against `server_name_override`,
/// and host portion of URL will not be used for server authentication.
pub fn override_server_name(mut self, server_name_override: String) -> Self {
self.0.inner.server_name_override = Some(server_name_override);
self
}

/// This builds an [`HttpsConnector`] built on hyper's default [`HttpConnector`]
#[cfg(feature = "tokio-runtime")]
pub fn build(self) -> HttpsConnector<HttpConnector> {
Expand Down

0 comments on commit 76ced6d

Please sign in to comment.