diff --git a/Cargo.toml b/Cargo.toml index 1e517f824..77d06ef1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,8 @@ bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" tower-service = "0.3" +async-trait = "0.1" +dyn-clone = "1" futures-core = { version = "0.3.0", default-features = false } futures-util = { version = "0.3.0", default-features = false } diff --git a/examples/custom_proxy_protocol.rs b/examples/custom_proxy_protocol.rs new file mode 100644 index 000000000..ce06c10e0 --- /dev/null +++ b/examples/custom_proxy_protocol.rs @@ -0,0 +1,98 @@ +use std::{error::Error, io::Write, pin::pin}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; + +use async_trait::async_trait; +use http::Uri; +use reqwest::{AsyncStreamWrapper, Client, CustomProxyProtocol, Proxy}; + +#[tokio::main] +async fn main() { + let proxy: Box = Box::new(Example()); + let client = Client::builder() + .proxy(Proxy::all(proxy).unwrap()) + .http1_only() + .build() + .unwrap(); + let mut response = client + .get("http://www.hal.ipc.i.u-tokyo.ac.jp/~nakada/prog2015/alice.txt") + .send() + .await + .unwrap(); + + let mut stdout = std::io::stdout(); + while let Some(chunk) = response.chunk().await.unwrap() { + stdout.write_all(&chunk).unwrap(); + } + stdout.flush().unwrap(); +} + +#[derive(Clone)] +struct Example(); +#[async_trait] +impl CustomProxyProtocol for Example { + async fn connect( + &self, + dst: Uri, + ) -> Result> { + let host = dst.host().ok_or("host is None")?; + let port = match (dst.scheme_str(), dst.port_u16()) { + (_, Some(p)) => p, + (Some("http"), None) => 80, + (Some("https"), None) => 443, + _ => return Err("scheme is unknown and port is None.".into()), + }; + eprintln!("Connecting to {}:{}", host, port); + Ok(AsyncStreamWrapper::new( + WrapStream(TcpStream::connect(format!("{}:{}", host, port)).await?), + false, + )) + } +} + +struct WrapStream(RW) +where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static; +impl AsyncRead for WrapStream +where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + eprintln!("read"); + pin!(&mut self.0).poll_read(cx, buf) + } +} +impl AsyncWrite for WrapStream +where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + eprintln!("write"); + std::io::stderr().write_all(buf).unwrap(); + pin!(&mut self.0).poll_write(cx, buf) + } + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + eprintln!("flush"); + pin!(&mut self.0).poll_flush(cx) + } + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + eprintln!("shutdown"); + pin!(&mut self.0).poll_shutdown(cx) + } +} diff --git a/src/connect.rs b/src/connect.rs index c171dd18d..5840e3581 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -23,7 +23,7 @@ use self::native_tls_conn::NativeTlsConn; use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; use crate::error::BoxError; -use crate::proxy::{Proxy, ProxyScheme}; +use crate::proxy::{AsyncStreamWrapper, Proxy, ProxyScheme}; pub(crate) type HttpConnector = hyper::client::HttpConnector; @@ -179,7 +179,7 @@ impl Connector { ProxyScheme::Socks5 { remote_dns: true, .. } => socks::DnsResolve::Proxy, - ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => { + _ => { unreachable!("connect_socks is only called for socks proxies"); } }; @@ -319,6 +319,54 @@ impl Connector { let (proxy_dst, _auth) = match proxy_scheme { ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth), ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth), + ProxyScheme::Custom(ref p) => { + let p = p.clone(); + match &self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(_http, tls) => { + if dst.scheme() == Some(&Scheme::HTTPS) { + let host = dst.host().ok_or("no host in url")?.to_string(); + let conn = p.connect(dst).await?; + let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); + let io = tls_connector.connect(&host, conn).await?; + return Ok(Conn { + inner: self.verbose.wrap(NativeTlsConn { inner: io }), + is_proxy: false, + tls_info: self.tls_info, + }); + } + } + #[cfg(feature = "__rustls")] + Inner::RustlsTls { tls, .. } => { + if dst.scheme() == Some(&Scheme::HTTPS) { + use std::convert::TryFrom; + use tokio_rustls::TlsConnector as RustlsConnector; + + let host = dst.host().ok_or("no host in url")?.to_string(); + let conn = p.connect(dst).await?; + let server_name = rustls::ServerName::try_from(host.as_str()) + .map_err(|_| "Invalid Server Name")?; + let tls = tls.clone(); + let io = RustlsConnector::from(tls) + .connect(server_name, conn) + .await?; + return Ok(Conn { + inner: self.verbose.wrap(RustlsTlsConn { inner: io }), + is_proxy: false, + tls_info: false, + }); + } + } + #[cfg(not(feature = "__tls"))] + Inner::Http(_) => (), + } + + return p.connect(dst).await.map(|tcp| Conn { + is_proxy: tcp.is_http_proxy, + inner: self.verbose.wrap(tcp), + tls_info: false, + }); + } #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await, }; @@ -466,6 +514,13 @@ trait TlsInfoFactory { fn tls_info(&self) -> Option; } +#[cfg(feature = "__tls")] +impl TlsInfoFactory for AsyncStreamWrapper { + fn tls_info(&self) -> Option { + None + } +} + #[cfg(feature = "__tls")] impl TlsInfoFactory for tokio::net::TcpStream { fn tls_info(&self) -> Option { @@ -474,7 +529,10 @@ impl TlsInfoFactory for tokio::net::TcpStream { } #[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { +impl TlsInfoFactory for hyper_tls::MaybeHttpsStream +where + RW: AsyncRead + AsyncWrite + Unpin, +{ fn tls_info(&self) -> Option { match self { hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(), @@ -484,20 +542,10 @@ impl TlsInfoFactory for hyper_tls::MaybeHttpsStream { } #[cfg(feature = "default-tls")] -impl TlsInfoFactory for hyper_tls::TlsStream> { - fn tls_info(&self) -> Option { - let peer_certificate = self - .get_ref() - .peer_certificate() - .ok() - .flatten() - .and_then(|c| c.to_der().ok()); - Some(crate::tls::TlsInfo { peer_certificate }) - } -} - -#[cfg(feature = "default-tls")] -impl TlsInfoFactory for tokio_native_tls::TlsStream { +impl TlsInfoFactory for tokio_native_tls::TlsStream +where + RW: AsyncRead + AsyncWrite + Unpin, +{ fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -510,7 +558,7 @@ impl TlsInfoFactory for tokio_native_tls::TlsStream { } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { +impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { fn tls_info(&self) -> Option { match self { hyper_rustls::MaybeHttpsStream::Https(tls) => tls.tls_info(), @@ -520,22 +568,7 @@ impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream { } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::TlsStream { - fn tls_info(&self) -> Option { - let peer_certificate = self - .get_ref() - .1 - .peer_certificates() - .and_then(|certs| certs.first()) - .map(|c| c.0.clone()); - Some(crate::tls::TlsInfo { peer_certificate }) - } -} - -#[cfg(feature = "__rustls")] -impl TlsInfoFactory - for tokio_rustls::client::TlsStream> -{ +impl TlsInfoFactory for tokio_rustls::TlsStream { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -548,7 +581,7 @@ impl TlsInfoFactory } #[cfg(feature = "__rustls")] -impl TlsInfoFactory for tokio_rustls::client::TlsStream { +impl TlsInfoFactory for tokio_rustls::client::TlsStream { fn tls_info(&self) -> Option { let peer_certificate = self .get_ref() @@ -561,11 +594,10 @@ impl TlsInfoFactory for tokio_rustls::client::TlsStream { } pub(crate) trait AsyncConn: - AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static + AsyncRead + AsyncWrite + Connection + Send + Unpin + 'static { } - -impl AsyncConn for T {} +impl AsyncConn for T {} #[cfg(feature = "__tls")] trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {} @@ -824,13 +856,10 @@ mod native_tls_conn { } } - impl TlsInfoFactory for NativeTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for NativeTlsConn> { + impl TlsInfoFactory for NativeTlsConn + where + RW: AsyncRead + AsyncWrite + Unpin, + { fn tls_info(&self) -> Option { self.inner.tls_info() } @@ -917,13 +946,7 @@ mod rustls_tls_conn { } } - impl TlsInfoFactory for RustlsTlsConn { - fn tls_info(&self) -> Option { - self.inner.tls_info() - } - } - - impl TlsInfoFactory for RustlsTlsConn> { + impl TlsInfoFactory for RustlsTlsConn { fn tls_info(&self) -> Option { self.inner.tls_info() } diff --git a/src/lib.rs b/src/lib.rs index 188ba4f02..eafc6d154 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -326,7 +326,7 @@ if_hyper! { pub use self::async_impl::{ Body, Client, ClientBuilder, Request, RequestBuilder, Response, Upgraded, }; - pub use self::proxy::{Proxy,NoProxy}; + pub use self::proxy::{Proxy,NoProxy,CustomProxyProtocol,AsyncStreamWrapper}; #[cfg(feature = "__tls")] // Re-exports, to be removed in a future release pub use tls::{Certificate, Identity}; diff --git a/src/proxy.rs b/src/proxy.rs index 6e1bfcc73..6980e41e5 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,11 +1,14 @@ -use std::fmt; +use std::fmt::{self, Debug}; #[cfg(feature = "socks")] use std::net::SocketAddr; use std::sync::Arc; use crate::into_url::{IntoUrl, IntoUrlSealed}; use crate::Url; +use async_trait::async_trait; +use dyn_clone::DynClone; use http::{header::HeaderValue, Uri}; +use hyper::client::connect::{Connected, Connection}; use ipnet::IpNet; use once_cell::sync::Lazy; use percent_encoding::percent_decode; @@ -13,6 +16,7 @@ use std::collections::HashMap; use std::env; use std::error::Error; use std::net::IpAddr; +use std::pin::pin; #[cfg(target_os = "macos")] use system_configuration::{ core_foundation::{ @@ -29,6 +33,7 @@ use system_configuration::{ sys::schema_definitions::kSCPropNetProxiesHTTPSPort, sys::schema_definitions::kSCPropNetProxiesHTTPSProxy, }; +use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(target_os = "windows")] use winreg::enums::HKEY_CURRENT_USER; #[cfg(target_os = "windows")] @@ -96,6 +101,129 @@ pub struct NoProxy { domains: DomainMatcher, } +pub trait AsyncStream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} +impl AsyncStream for RW {} +/// A wrapper for proxy connections and related information. +/// return type of [CustomProxyProtocol::connect]. +pub struct AsyncStreamWrapper { + pub(crate) inner: Box, + pub(crate) is_http_proxy: bool, +} +impl AsyncStreamWrapper { + /// Make a new instance of [AsyncStreamWrapper]. + /// If is_http_proxy is set to true, the connection will be treated as a connection to an http proxy. + /// This does not affect https. + pub fn new(stream: RW, is_http_proxy: bool) -> Self + where + RW: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + Self { + inner: Box::new(stream), + is_http_proxy, + } + } +} +impl AsyncRead for AsyncStreamWrapper { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + pin!(&mut self.inner).poll_read(cx, buf) + } +} +impl AsyncWrite for AsyncStreamWrapper { + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.inner).poll_flush(cx) + } + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(&mut self.inner).poll_shutdown(cx) + } + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + pin!(&mut self.inner).poll_write(cx, buf) + } + fn poll_write_vectored( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + pin!(&mut self.inner).poll_write_vectored(cx, bufs) + } + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} +impl Debug for AsyncStreamWrapper { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AsyncStreamWrapper") + } +} +impl Connection for AsyncStreamWrapper { + fn connected(&self) -> Connected { + Connected::new() + } +} + +/// A trait to define custom proxy protocol. +/// `Box` implements `IntoProxyScheme`. +/// # Example +/// ``` +/// use std::error::Error; +/// use tokio::net::TcpStream; +/// +/// use async_trait::async_trait; +/// use http::Uri; +/// use reqwest::{AsyncStreamWrapper, CustomProxyProtocol}; +/// +/// #[derive(Clone)] +/// struct Example(); +/// #[async_trait] +/// impl CustomProxyProtocol for Example { +/// async fn connect( +/// &self, +/// dst: Uri, +/// ) -> Result> { +/// let host = dst.host().ok_or("host is None")?; +/// let port = match (dst.scheme_str(), dst.port_u16()) { +/// (_, Some(p)) => p, +/// (Some("http"), None) => 80, +/// (Some("https"), None) => 443, +/// _ => return Err("scheme is unknown and port is None.".into()), +/// }; +/// eprintln!("Connecting to {}:{}", host, port); +/// Ok(AsyncStreamWrapper::new( +/// TcpStream::connect(format!("{}:{}", host, port)).await?, +/// false, +/// )) +/// } +/// } +/// ``` +#[async_trait] +pub trait CustomProxyProtocol: Sync + Send + DynClone { + /// Establish an TCP connection to the web server. + async fn connect( + &self, + dst: Uri, + ) -> Result>; +} +dyn_clone::clone_trait_object!(CustomProxyProtocol); + +impl IntoProxyScheme for Box { + fn into_proxy_scheme(self) -> crate::Result { + Ok(ProxyScheme::Custom(self)) + } +} + /// A particular scheme used for proxying requests. /// /// For example, HTTP vs SOCKS5 @@ -109,6 +237,7 @@ pub enum ProxyScheme { auth: Option, host: http::uri::Authority, }, + Custom(Box), #[cfg(feature = "socks")] Socks5 { addr: SocketAddr, @@ -121,7 +250,6 @@ impl ProxyScheme { fn maybe_http_auth(&self) -> Option<&HeaderValue> { match self { ProxyScheme::Http { auth, .. } | ProxyScheme::Https { auth, .. } => auth.as_ref(), - #[cfg(feature = "socks")] _ => None, } } @@ -612,6 +740,7 @@ impl ProxyScheme { let header = encode_basic_auth(&username.into(), &password.into()); *auth = Some(header); } + ProxyScheme::Custom(_) => {} #[cfg(feature = "socks")] ProxyScheme::Socks5 { ref mut auth, .. } => { *auth = Some((username.into(), password.into())); @@ -631,6 +760,7 @@ impl ProxyScheme { *auth = update.clone(); } } + ProxyScheme::Custom(_) => {} #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => {} } @@ -684,6 +814,7 @@ impl ProxyScheme { match self { ProxyScheme::Http { .. } => "http", ProxyScheme::Https { .. } => "https", + ProxyScheme::Custom(_) => "custom", #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => "socks5", } @@ -694,6 +825,7 @@ impl ProxyScheme { match self { ProxyScheme::Http { host, .. } => host.as_str(), ProxyScheme::Https { host, .. } => host.as_str(), + ProxyScheme::Custom(_) => panic!("custom"), #[cfg(feature = "socks")] ProxyScheme::Socks5 { .. } => panic!("socks5"), } @@ -705,6 +837,7 @@ impl fmt::Debug for ProxyScheme { match self { ProxyScheme::Http { auth: _auth, host } => write!(f, "http://{}", host), ProxyScheme::Https { auth: _auth, host } => write!(f, "https://{}", host), + ProxyScheme::Custom(_) => write!(f, "custom://"), #[cfg(feature = "socks")] ProxyScheme::Socks5 { addr, @@ -1075,8 +1208,7 @@ mod tests { let (scheme, host) = match p.intercept(&url(s)).unwrap() { ProxyScheme::Http { host, .. } => ("http", host), ProxyScheme::Https { host, .. } => ("https", host), - #[cfg(feature = "socks")] - _ => panic!("intercepted as socks"), + _ => panic!("intercepted as not http or https"), }; http::Uri::builder() .scheme(scheme)