From d779733f2dc1e96dec4a1bba8002b9c084d278cf Mon Sep 17 00:00:00 2001 From: Mariusz Masztalerczuk Date: Thu, 26 Mar 2020 10:55:36 +0100 Subject: [PATCH] Support websocket proxies configured via environment variables PR #4661 Resolves #4648 Co-authored-by: Sviatoslav Sydorenko --- CHANGES/4648.bugfix | 1 + CONTRIBUTORS.txt | 1 + aiohttp/helpers.py | 65 ++++++++++++++++++++++++++++++---------- docs/client_advanced.rst | 4 +-- tests/test_helpers.py | 53 +++++++++++++++++--------------- 5 files changed, 82 insertions(+), 42 deletions(-) create mode 100644 CHANGES/4648.bugfix diff --git a/CHANGES/4648.bugfix b/CHANGES/4648.bugfix new file mode 100644 index 00000000000..094eb9d4925 --- /dev/null +++ b/CHANGES/4648.bugfix @@ -0,0 +1 @@ +Fix supporting WebSockets proxies configured via environment variables. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index ad63ce9e4de..58626ee4dcd 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -186,6 +186,7 @@ Manuel Miranda Marat Sharafutdinov Marco Paolini Mariano Anaya +Mariusz Masztalerczuk Martijn Pieters Martin Melka Martin Richard diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index bbf5f1298fb..4a7ca01c19a 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -90,7 +90,8 @@ def all_tasks( # N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr # for compatibility with older versions DEBUG = getattr(sys.flags, "dev_mode", False) or ( - not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG")) + not sys.flags.ignore_environment + and bool(os.environ.get("PYTHONASYNCIODEBUG")) ) # type: bool @@ -140,7 +141,9 @@ def __new__( raise ValueError("None is not allowed as password value") if ":" in login: - raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)') + raise ValueError( + 'A ":" is not allowed in login (RFC 1945#section-11.1)' + ) return super().__new__(cls, login, password, encoding) @@ -174,7 +177,9 @@ def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": return cls(username, password, encoding=encoding) @classmethod - def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]: + def from_url( + cls, url: URL, *, encoding: str = "latin1" + ) -> Optional["BasicAuth"]: """Create BasicAuth from url.""" if not isinstance(url, URL): raise TypeError("url should be yarl.URL instance") @@ -243,14 +248,22 @@ class ProxyInfo: def proxies_from_env() -> Dict[str, ProxyInfo]: - proxy_urls = {k: URL(v) for k, v in getproxies().items() if k in ("http", "https")} + proxy_urls = { + k: URL(v) + for k, v in getproxies().items() + if k in ("http", "https", "ws", "wss") + } netrc_obj = netrc_from_env() stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()} ret = {} for proto, val in stripped.items(): proxy, auth = val - if proxy.scheme == "https": - client_logger.warning("HTTPS proxies %s are not supported, ignoring", proxy) + if proxy.scheme in ("https", "wss"): + client_logger.warning( + "%s proxies %s are not supported, ignoring", + proxy.scheme.upper(), + proxy, + ) continue if netrc_obj and auth is None: auth_from_netrc = None @@ -289,7 +302,8 @@ def get_running_loop( ) if loop.get_debug(): internal_logger.warning( - "The object should be created within an async function", stack_info=True + "The object should be created within an async function", + stack_info=True, ) return loop @@ -327,7 +341,10 @@ def parse_mimetype(mimetype: str) -> MimeType: """ if not mimetype: return MimeType( - type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict()) + type="", + subtype="", + suffix="", + parameters=MultiDictProxy(MultiDict()), ) parts = mimetype.split(";") @@ -350,11 +367,16 @@ def parse_mimetype(mimetype: str) -> MimeType: else (fulltype, "") ) stype, suffix = ( - cast(Tuple[str, str], stype.split("+", 1)) if "+" in stype else (stype, "") + cast(Tuple[str, str], stype.split("+", 1)) + if "+" in stype + else (stype, "") ) return MimeType( - type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params) + type=mtype, + subtype=stype, + suffix=suffix, + parameters=MultiDictProxy(params), ) @@ -376,7 +398,9 @@ def content_disposition_header( params is a dict with disposition params. """ if not disptype or not (TOKEN > set(disptype)): - raise ValueError("bad content disposition type {!r}" "".format(disptype)) + raise ValueError( + "bad content disposition type {!r}" "".format(disptype) + ) value = disptype if params: @@ -384,7 +408,8 @@ def content_disposition_header( for key, val in params.items(): if not key or not (TOKEN > set(key)): raise ValueError( - "bad content disposition parameter" " {!r}={!r}".format(key, val) + "bad content disposition parameter" + " {!r}={!r}".format(key, val) ) qval = quote(val, "") if quote_fields else val lparams.append((key, '"%s"' % qval)) @@ -461,7 +486,9 @@ def __set__(self, inst: _TSelf, value: _T) -> None: def _is_ip_address( - regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]] + regex: Pattern[str], + regexb: Pattern[bytes], + host: Optional[Union[str, bytes]], ) -> bool: if host is None: return False @@ -470,14 +497,18 @@ def _is_ip_address( elif isinstance(host, (bytes, bytearray, memoryview)): return bool(regexb.match(host)) else: - raise TypeError("{} [{}] is not a str or bytes".format(host, type(host))) + raise TypeError( + "{} [{}] is not a str or bytes".format(host, type(host)) + ) is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb) is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb) -def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool: +def is_ip_address( + host: Optional[Union[str, bytes, bytearray, memoryview]] +) -> bool: return is_ipv4_address(host) or is_ipv6_address(host) @@ -683,7 +714,9 @@ def __enter__(self) -> async_timeout.timeout: class HeadersMixin: - ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"]) + ATTRS = frozenset( + ["_content_type", "_content_dict", "_stored_content_type"] + ) _content_type = None # type: Optional[str] _content_dict = None # type: Optional[Dict[str, str]] diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index e4e0919c7f0..03365584cc8 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -549,8 +549,8 @@ Contrary to the ``requests`` library, it won't read environment variables by default. But you can do so by passing ``trust_env=True`` into :class:`aiohttp.ClientSession` constructor for extracting proxy configuration from -*HTTP_PROXY* or *HTTPS_PROXY* *environment variables* (both are case -insensitive):: +*HTTP_PROXY*, *HTTPS_PROXY*, *WS_PROXY* or *WSS_PROXY* *environment +variables* (all are case insensitive):: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get("http://python.org") as resp: diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 3367c24b78a..f485a8eca4b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -25,7 +25,10 @@ [ ("", helpers.MimeType("", "", "", MultiDict())), ("*", helpers.MimeType("*", "*", "", MultiDict())), - ("application/json", helpers.MimeType("application", "json", "", MultiDict())), + ( + "application/json", + helpers.MimeType("application", "json", "", MultiDict()), + ), ( "application/json; charset=utf-8", helpers.MimeType( @@ -147,15 +150,21 @@ def test_basic_auth_decode_invalid_credentials() -> None: (":", helpers.BasicAuth(login="", password="", encoding="latin1")), ( "username:", - helpers.BasicAuth(login="username", password="", encoding="latin1"), + helpers.BasicAuth( + login="username", password="", encoding="latin1" + ), ), ( ":password", - helpers.BasicAuth(login="", password="password", encoding="latin1"), + helpers.BasicAuth( + login="", password="password", encoding="latin1" + ), ), ( "username:password", - helpers.BasicAuth(login="username", password="password", encoding="latin1"), + helpers.BasicAuth( + login="username", password="password", encoding="latin1" + ), ), ), ) @@ -472,32 +481,28 @@ def test_set_content_disposition_bad_param() -> None: # --------------------- proxies_from_env ------------------------------ -def test_proxies_from_env_http(mocker) -> None: - url = URL("http://aiohttp.io/path") - mocker.patch.dict(os.environ, {"http_proxy": str(url)}) - ret = helpers.proxies_from_env() - assert ret.keys() == {"http"} - assert ret["http"].proxy == url - assert ret["http"].proxy_auth is None - - -def test_proxies_from_env_http_proxy_for_https_proto(mocker) -> None: +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"]) +def test_proxies_from_env(monkeypatch, protocol) -> None: url = URL("http://aiohttp.io/path") - mocker.patch.dict(os.environ, {"https_proxy": str(url)}) + monkeypatch.setenv(protocol + "_proxy", str(url)) ret = helpers.proxies_from_env() - assert ret.keys() == {"https"} - assert ret["https"].proxy == url - assert ret["https"].proxy_auth is None + assert ret.keys() == {protocol} + assert ret[protocol].proxy == url + assert ret[protocol].proxy_auth is None -def test_proxies_from_env_https_proxy_skipped(mocker) -> None: - url = URL("https://aiohttp.io/path") - mocker.patch.dict(os.environ, {"https_proxy": str(url)}) - log = mocker.patch("aiohttp.log.client_logger.warning") +@pytest.mark.parametrize("protocol", ["https", "wss"]) +def test_proxies_from_env_skipped(monkeypatch, caplog, protocol) -> None: + url = URL(protocol + "://aiohttp.io/path") + monkeypatch.setenv(protocol + "_proxy", str(url)) assert helpers.proxies_from_env() == {} - log.assert_called_with( - "HTTPS proxies %s are not supported, ignoring", URL("https://aiohttp.io/path") + assert len(caplog.records) == 1 + log_message = ( + "{proto!s} proxies {url!s} are not supported, ignoring".format( + proto=protocol.upper(), url=url + ) ) + assert caplog.record_tuples == [("aiohttp.client", 30, log_message)] def test_proxies_from_env_http_with_auth(mocker) -> None: