diff --git a/src/connector.rs b/src/connector.rs index e97be1d..076870e 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -21,6 +21,7 @@ pub struct HttpsConnector { force_https: bool, http: T, tls_config: Arc, + override_server_name: Option, } impl fmt::Debug for HttpsConnector { @@ -40,6 +41,7 @@ where force_https: false, http, tls_config: cfg.into(), + override_server_name: None, } } } @@ -83,10 +85,17 @@ where Box::pin(f) } else if sch == &http::uri::Scheme::HTTPS { let cfg = self.tls_config.clone(); - let hostname = dst - .host() - .unwrap_or_default() - .to_string(); + let hostname = match self.override_server_name.as_deref() { + Some(h) => h, + None => dst.host().unwrap_or_default(), + }; + let hostname = match rustls::ServerName::try_from(hostname) { + Ok(dnsname) => dnsname, + Err(_) => { + let err = io::Error::new(io::ErrorKind::Other, "invalid dnsname"); + return Box::pin(async move { Err(Box::new(err).into()) }); + } + }; let connecting_future = self.http.call(dst); let f = async move { @@ -94,10 +103,8 @@ where .await .map_err(Into::into)?; let connector = TlsConnector::from(cfg); - let dnsname = rustls::ServerName::try_from(hostname.as_str()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid dnsname"))?; let tls = connector - .connect(dnsname, tcp) + .connect(hostname, tcp) .await .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; Ok(MaybeHttpsStream::Https(tls)) diff --git a/src/connector/builder.rs b/src/connector/builder.rs index 0398f8b..dd4f48c 100644 --- a/src/connector/builder.rs +++ b/src/connector/builder.rs @@ -106,6 +106,7 @@ impl ConnectorBuilder { ConnectorBuilder(WantsProtocols1 { tls_config: self.0.tls_config, https_only: true, + override_server_name: None, }) } @@ -117,6 +118,7 @@ impl ConnectorBuilder { ConnectorBuilder(WantsProtocols1 { tls_config: self.0.tls_config, https_only: false, + override_server_name: None, }) } } @@ -128,6 +130,7 @@ impl ConnectorBuilder { pub struct WantsProtocols1 { tls_config: ClientConfig, https_only: bool, + override_server_name: Option, } impl WantsProtocols1 { @@ -136,6 +139,7 @@ impl WantsProtocols1 { force_https: self.https_only, http: conn, tls_config: std::sync::Arc::new(self.tls_config), + override_server_name: self.override_server_name, } } @@ -169,6 +173,20 @@ impl ConnectorBuilder { enable_http1: false, }) } + + /// Override server name for the TLS stack + /// + /// By default, for each connection hyper-rustls will extract host portion + /// of the destination URL and verify that server certificate contains + /// this value. + /// + /// If this method is called, hyper-rustls will instead verify that server + /// certificate contains `override_server_name`. Domain name included in + /// the URL will not affect certificate validation. + pub fn with_server_name(mut self, override_server_name: String) -> Self { + self.0.override_server_name = Some(override_server_name); + self + } } /// State of a builder with HTTP1 enabled, that may have some other