Skip to content

Commit

Permalink
Support websocket proxies configured via environment variables
Browse files Browse the repository at this point in the history
PR aio-libs#4661

Resolves aio-libs#4648

Co-authored-by: Sviatoslav Sydorenko <wk.cvs.github@sydorenko.org.ua>
  • Loading branch information
2 people authored and icamposrivera committed Oct 21, 2021
1 parent 184274d commit d779733
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGES/4648.bugfix
@@ -0,0 +1 @@
Fix supporting WebSockets proxies configured via environment variables.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Expand Up @@ -186,6 +186,7 @@ Manuel Miranda
Marat Sharafutdinov
Marco Paolini
Mariano Anaya
Mariusz Masztalerczuk
Martijn Pieters
Martin Melka
Martin Richard
Expand Down
65 changes: 49 additions & 16 deletions aiohttp/helpers.py
Expand Up @@ -90,7 +90,8 @@ def all_tasks(
# N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr
# for compatibility with older versions
DEBUG = getattr(sys.flags, "dev_mode", False) or (
not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
not sys.flags.ignore_environment
and bool(os.environ.get("PYTHONASYNCIODEBUG"))
) # type: bool


Expand Down Expand Up @@ -140,7 +141,9 @@ def __new__(
raise ValueError("None is not allowed as password value")

if ":" in login:
raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
raise ValueError(
'A ":" is not allowed in login (RFC 1945#section-11.1)'
)

return super().__new__(cls, login, password, encoding)

Expand Down Expand Up @@ -174,7 +177,9 @@ def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
return cls(username, password, encoding=encoding)

@classmethod
def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
def from_url(
cls, url: URL, *, encoding: str = "latin1"
) -> Optional["BasicAuth"]:
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
Expand Down Expand Up @@ -243,14 +248,22 @@ class ProxyInfo:


def proxies_from_env() -> Dict[str, ProxyInfo]:
proxy_urls = {k: URL(v) for k, v in getproxies().items() if k in ("http", "https")}
proxy_urls = {
k: URL(v)
for k, v in getproxies().items()
if k in ("http", "https", "ws", "wss")
}
netrc_obj = netrc_from_env()
stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()}
ret = {}
for proto, val in stripped.items():
proxy, auth = val
if proxy.scheme == "https":
client_logger.warning("HTTPS proxies %s are not supported, ignoring", proxy)
if proxy.scheme in ("https", "wss"):
client_logger.warning(
"%s proxies %s are not supported, ignoring",
proxy.scheme.upper(),
proxy,
)
continue
if netrc_obj and auth is None:
auth_from_netrc = None
Expand Down Expand Up @@ -289,7 +302,8 @@ def get_running_loop(
)
if loop.get_debug():
internal_logger.warning(
"The object should be created within an async function", stack_info=True
"The object should be created within an async function",
stack_info=True,
)
return loop

Expand Down Expand Up @@ -327,7 +341,10 @@ def parse_mimetype(mimetype: str) -> MimeType:
"""
if not mimetype:
return MimeType(
type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
type="",
subtype="",
suffix="",
parameters=MultiDictProxy(MultiDict()),
)

parts = mimetype.split(";")
Expand All @@ -350,11 +367,16 @@ def parse_mimetype(mimetype: str) -> MimeType:
else (fulltype, "")
)
stype, suffix = (
cast(Tuple[str, str], stype.split("+", 1)) if "+" in stype else (stype, "")
cast(Tuple[str, str], stype.split("+", 1))
if "+" in stype
else (stype, "")
)

return MimeType(
type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
type=mtype,
subtype=stype,
suffix=suffix,
parameters=MultiDictProxy(params),
)


Expand All @@ -376,15 +398,18 @@ def content_disposition_header(
params is a dict with disposition params.
"""
if not disptype or not (TOKEN > set(disptype)):
raise ValueError("bad content disposition type {!r}" "".format(disptype))
raise ValueError(
"bad content disposition type {!r}" "".format(disptype)
)

value = disptype
if params:
lparams = []
for key, val in params.items():
if not key or not (TOKEN > set(key)):
raise ValueError(
"bad content disposition parameter" " {!r}={!r}".format(key, val)
"bad content disposition parameter"
" {!r}={!r}".format(key, val)
)
qval = quote(val, "") if quote_fields else val
lparams.append((key, '"%s"' % qval))
Expand Down Expand Up @@ -461,7 +486,9 @@ def __set__(self, inst: _TSelf, value: _T) -> None:


def _is_ip_address(
regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]]
regex: Pattern[str],
regexb: Pattern[bytes],
host: Optional[Union[str, bytes]],
) -> bool:
if host is None:
return False
Expand All @@ -470,14 +497,18 @@ def _is_ip_address(
elif isinstance(host, (bytes, bytearray, memoryview)):
return bool(regexb.match(host))
else:
raise TypeError("{} [{}] is not a str or bytes".format(host, type(host)))
raise TypeError(
"{} [{}] is not a str or bytes".format(host, type(host))
)


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
def is_ip_address(
host: Optional[Union[str, bytes, bytearray, memoryview]]
) -> bool:
return is_ipv4_address(host) or is_ipv6_address(host)


Expand Down Expand Up @@ -683,7 +714,9 @@ def __enter__(self) -> async_timeout.timeout:

class HeadersMixin:

ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])
ATTRS = frozenset(
["_content_type", "_content_dict", "_stored_content_type"]
)

_content_type = None # type: Optional[str]
_content_dict = None # type: Optional[Dict[str, str]]
Expand Down
4 changes: 2 additions & 2 deletions docs/client_advanced.rst
Expand Up @@ -549,8 +549,8 @@ Contrary to the ``requests`` library, it won't read environment
variables by default. But you can do so by passing
``trust_env=True`` into :class:`aiohttp.ClientSession`
constructor for extracting proxy configuration from
*HTTP_PROXY* or *HTTPS_PROXY* *environment variables* (both are case
insensitive)::
*HTTP_PROXY*, *HTTPS_PROXY*, *WS_PROXY* or *WSS_PROXY* *environment
variables* (all are case insensitive)::

async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get("http://python.org") as resp:
Expand Down
53 changes: 29 additions & 24 deletions tests/test_helpers.py
Expand Up @@ -25,7 +25,10 @@
[
("", helpers.MimeType("", "", "", MultiDict())),
("*", helpers.MimeType("*", "*", "", MultiDict())),
("application/json", helpers.MimeType("application", "json", "", MultiDict())),
(
"application/json",
helpers.MimeType("application", "json", "", MultiDict()),
),
(
"application/json; charset=utf-8",
helpers.MimeType(
Expand Down Expand Up @@ -147,15 +150,21 @@ def test_basic_auth_decode_invalid_credentials() -> None:
(":", helpers.BasicAuth(login="", password="", encoding="latin1")),
(
"username:",
helpers.BasicAuth(login="username", password="", encoding="latin1"),
helpers.BasicAuth(
login="username", password="", encoding="latin1"
),
),
(
":password",
helpers.BasicAuth(login="", password="password", encoding="latin1"),
helpers.BasicAuth(
login="", password="password", encoding="latin1"
),
),
(
"username:password",
helpers.BasicAuth(login="username", password="password", encoding="latin1"),
helpers.BasicAuth(
login="username", password="password", encoding="latin1"
),
),
),
)
Expand Down Expand Up @@ -472,32 +481,28 @@ def test_set_content_disposition_bad_param() -> None:
# --------------------- proxies_from_env ------------------------------


def test_proxies_from_env_http(mocker) -> None:
url = URL("http://aiohttp.io/path")
mocker.patch.dict(os.environ, {"http_proxy": str(url)})
ret = helpers.proxies_from_env()
assert ret.keys() == {"http"}
assert ret["http"].proxy == url
assert ret["http"].proxy_auth is None


def test_proxies_from_env_http_proxy_for_https_proto(mocker) -> None:
@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"])
def test_proxies_from_env(monkeypatch, protocol) -> None:
url = URL("http://aiohttp.io/path")
mocker.patch.dict(os.environ, {"https_proxy": str(url)})
monkeypatch.setenv(protocol + "_proxy", str(url))
ret = helpers.proxies_from_env()
assert ret.keys() == {"https"}
assert ret["https"].proxy == url
assert ret["https"].proxy_auth is None
assert ret.keys() == {protocol}
assert ret[protocol].proxy == url
assert ret[protocol].proxy_auth is None


def test_proxies_from_env_https_proxy_skipped(mocker) -> None:
url = URL("https://aiohttp.io/path")
mocker.patch.dict(os.environ, {"https_proxy": str(url)})
log = mocker.patch("aiohttp.log.client_logger.warning")
@pytest.mark.parametrize("protocol", ["https", "wss"])
def test_proxies_from_env_skipped(monkeypatch, caplog, protocol) -> None:
url = URL(protocol + "://aiohttp.io/path")
monkeypatch.setenv(protocol + "_proxy", str(url))
assert helpers.proxies_from_env() == {}
log.assert_called_with(
"HTTPS proxies %s are not supported, ignoring", URL("https://aiohttp.io/path")
assert len(caplog.records) == 1
log_message = (
"{proto!s} proxies {url!s} are not supported, ignoring".format(
proto=protocol.upper(), url=url
)
)
assert caplog.record_tuples == [("aiohttp.client", 30, log_message)]


def test_proxies_from_env_http_with_auth(mocker) -> None:
Expand Down

0 comments on commit d779733

Please sign in to comment.